diff --git a/tests/rest/client/v2_alpha/test_keys.py b/tests/rest/client/v2_alpha/test_keys.py
new file mode 100644
index 0000000000..80a4e728ff
--- /dev/null
+++ b/tests/rest/client/v2_alpha/test_keys.py
@@ -0,0 +1,77 @@
+from http import HTTPStatus
+
+from synapse.api.errors import Codes
+from synapse.rest import admin
+from synapse.rest.client import keys, login
+
+from tests import unittest
+
+
+class KeyQueryTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ keys.register_servlets,
+ admin.register_servlets_for_client_rest_resource,
+ login.register_servlets,
+ ]
+
+ def test_rejects_device_id_ice_key_outside_of_list(self):
+ self.register_user("alice", "wonderland")
+ alice_token = self.login("alice", "wonderland")
+ bob = self.register_user("bob", "uncle")
+ channel = self.make_request(
+ "POST",
+ "/_matrix/client/r0/keys/query",
+ {
+ "device_keys": {
+ bob: "device_id1",
+ },
+ },
+ alice_token,
+ )
+ self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
+ self.assertEqual(
+ channel.json_body["errcode"],
+ Codes.BAD_JSON,
+ channel.result,
+ )
+
+ def test_rejects_device_key_given_as_map_to_bool(self):
+ self.register_user("alice", "wonderland")
+ alice_token = self.login("alice", "wonderland")
+ bob = self.register_user("bob", "uncle")
+ channel = self.make_request(
+ "POST",
+ "/_matrix/client/r0/keys/query",
+ {
+ "device_keys": {
+ bob: {
+ "device_id1": True,
+ },
+ },
+ },
+ alice_token,
+ )
+
+ self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
+ self.assertEqual(
+ channel.json_body["errcode"],
+ Codes.BAD_JSON,
+ channel.result,
+ )
+
+ def test_requires_device_key(self):
+ """`device_keys` is required. We should complain if it's missing."""
+ self.register_user("alice", "wonderland")
+ alice_token = self.login("alice", "wonderland")
+ channel = self.make_request(
+ "POST",
+ "/_matrix/client/r0/keys/query",
+ {},
+ alice_token,
+ )
+ self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
+ self.assertEqual(
+ channel.json_body["errcode"],
+ Codes.BAD_JSON,
+ channel.result,
+ )
|