summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
authorDavid Robertson <davidr@element.io>2022-01-05 13:33:28 +0000
committerGitHub <noreply@github.com>2022-01-05 13:33:28 +0000
commit88a78c6577086527e4569541b09e437a1ca0d1a9 (patch)
tree18e4072e9ce696bb45553df3f8988cd03ac1d737 /tests
parentRefactor the way we set `outlier` (#11634) (diff)
downloadsynapse-88a78c6577086527e4569541b09e437a1ca0d1a9.tar.xz
Cache empty responses from `/user/devices` (#11587)
If we've never made a request to a remote homeserver, we should cache the response---even if the response is "this user has no devices".
Diffstat (limited to 'tests')
-rw-r--r--tests/handlers/test_e2e_keys.py96
-rw-r--r--tests/test_utils/__init__.py4
2 files changed, 98 insertions, 2 deletions
diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py
index ddcf3ee348..734ed84d78 100644
--- a/tests/handlers/test_e2e_keys.py
+++ b/tests/handlers/test_e2e_keys.py
@@ -13,8 +13,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import Iterable
 from unittest import mock
 
+from parameterized import parameterized
 from signedjson import key as key, sign as sign
 
 from twisted.internet import defer
@@ -23,6 +25,7 @@ from synapse.api.constants import RoomEncryptionAlgorithms
 from synapse.api.errors import Codes, SynapseError
 
 from tests import unittest
+from tests.test_utils import make_awaitable
 
 
 class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
@@ -765,6 +768,8 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
         remote_user_id = "@test:other"
         local_user_id = "@test:test"
 
+        # Pretend we're sharing a room with the user we're querying. If not,
+        # `_query_devices_for_destination` will return early.
         self.store.get_rooms_for_user = mock.Mock(
             return_value=defer.succeed({"some_room_id"})
         )
@@ -831,3 +836,94 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
                 }
             },
         )
+
+    @parameterized.expand(
+        [
+            # The remote homeserver's response indicates that this user has 0/1/2 devices.
+            ([],),
+            (["device_1"],),
+            (["device_1", "device_2"],),
+        ]
+    )
+    def test_query_all_devices_caches_result(self, device_ids: Iterable[str]):
+        """Test that requests for all of a remote user's devices are cached.
+
+        We do this by asserting that only one call over federation was made, and that
+        the two queries to the local homeserver produce the same response.
+        """
+        local_user_id = "@test:test"
+        remote_user_id = "@test:other"
+        request_body = {"device_keys": {remote_user_id: []}}
+
+        response_devices = [
+            {
+                "device_id": device_id,
+                "keys": {
+                    "algorithms": ["dummy"],
+                    "device_id": device_id,
+                    "keys": {f"dummy:{device_id}": "dummy"},
+                    "signatures": {device_id: {f"dummy:{device_id}": "dummy"}},
+                    "unsigned": {},
+                    "user_id": "@test:other",
+                },
+            }
+            for device_id in device_ids
+        ]
+
+        response_body = {
+            "devices": response_devices,
+            "user_id": remote_user_id,
+            "stream_id": 12345,  # an integer, according to the spec
+        }
+
+        e2e_handler = self.hs.get_e2e_keys_handler()
+
+        # Pretend we're sharing a room with the user we're querying. If not,
+        # `_query_devices_for_destination` will return early.
+        mock_get_rooms = mock.patch.object(
+            self.store,
+            "get_rooms_for_user",
+            new_callable=mock.MagicMock,
+            return_value=make_awaitable(["some_room_id"]),
+        )
+        mock_request = mock.patch.object(
+            self.hs.get_federation_client(),
+            "query_user_devices",
+            new_callable=mock.MagicMock,
+            return_value=make_awaitable(response_body),
+        )
+
+        with mock_get_rooms, mock_request as mocked_federation_request:
+            # Make the first query and sanity check it succeeds.
+            response_1 = self.get_success(
+                e2e_handler.query_devices(
+                    request_body,
+                    timeout=10,
+                    from_user_id=local_user_id,
+                    from_device_id="some_device_id",
+                )
+            )
+            self.assertEqual(response_1["failures"], {})
+
+            # We should have made a federation request to do so.
+            mocked_federation_request.assert_called_once()
+
+            # Reset the mock so we can prove we don't make a second federation request.
+            mocked_federation_request.reset_mock()
+
+            # Repeat the query.
+            response_2 = self.get_success(
+                e2e_handler.query_devices(
+                    request_body,
+                    timeout=10,
+                    from_user_id=local_user_id,
+                    from_device_id="some_device_id",
+                )
+            )
+            self.assertEqual(response_2["failures"], {})
+
+            # We should not have made a second federation request.
+            mocked_federation_request.assert_not_called()
+
+            # The two requests to the local homeserver should be identical.
+            self.assertEqual(response_1, response_2)
diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py
index 15ac2bfeba..f05a373aa0 100644
--- a/tests/test_utils/__init__.py
+++ b/tests/test_utils/__init__.py
@@ -19,7 +19,7 @@ import sys
 import warnings
 from asyncio import Future
 from binascii import unhexlify
-from typing import Any, Awaitable, Callable, TypeVar
+from typing import Awaitable, Callable, TypeVar
 from unittest.mock import Mock
 
 import attr
@@ -46,7 +46,7 @@ def get_awaitable_result(awaitable: Awaitable[TV]) -> TV:
     raise Exception("awaitable has not yet completed")
 
 
-def make_awaitable(result: Any) -> Awaitable[Any]:
+def make_awaitable(result: TV) -> Awaitable[TV]:
     """
     Makes an awaitable, suitable for mocking an `async` function.
     This uses Futures as they can be awaited multiple times so can be returned