diff --git a/tests/rest/admin/test_device.py b/tests/rest/admin/test_device.py
index c4afe5c3d9..a3679be205 100644
--- a/tests/rest/admin/test_device.py
+++ b/tests/rest/admin/test_device.py
@@ -12,9 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import json
import urllib.parse
+from parameterized import parameterized
+
import synapse.rest.admin
from synapse.api.errors import Codes
from synapse.rest.client import login
@@ -45,49 +46,23 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
self.other_user_device_id,
)
- def test_no_auth(self):
+ @parameterized.expand(["GET", "PUT", "DELETE"])
+ def test_no_auth(self, method: str):
"""
Try to get a device of an user without authentication.
"""
- channel = self.make_request("GET", self.url, b"{}")
-
- self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
-
- channel = self.make_request("PUT", self.url, b"{}")
-
- self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
-
- channel = self.make_request("DELETE", self.url, b"{}")
+ channel = self.make_request(method, self.url, b"{}")
self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
- def test_requester_is_no_admin(self):
+ @parameterized.expand(["GET", "PUT", "DELETE"])
+ def test_requester_is_no_admin(self, method: str):
"""
If the user is not a server admin, an error is returned.
"""
channel = self.make_request(
- "GET",
- self.url,
- access_token=self.other_user_token,
- )
-
- self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
-
- channel = self.make_request(
- "PUT",
- self.url,
- access_token=self.other_user_token,
- )
-
- self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
-
- channel = self.make_request(
- "DELETE",
+ method,
self.url,
access_token=self.other_user_token,
)
@@ -95,7 +70,8 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
- def test_user_does_not_exist(self):
+ @parameterized.expand(["GET", "PUT", "DELETE"])
+ def test_user_does_not_exist(self, method: str):
"""
Tests that a lookup for a user that does not exist returns a 404
"""
@@ -105,7 +81,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
)
channel = self.make_request(
- "GET",
+ method,
url,
access_token=self.admin_user_tok,
)
@@ -113,25 +89,8 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
- channel = self.make_request(
- "PUT",
- url,
- access_token=self.admin_user_tok,
- )
-
- self.assertEqual(404, channel.code, msg=channel.json_body)
- self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
-
- channel = self.make_request(
- "DELETE",
- url,
- access_token=self.admin_user_tok,
- )
-
- self.assertEqual(404, channel.code, msg=channel.json_body)
- self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
-
- def test_user_is_not_local(self):
+ @parameterized.expand(["GET", "PUT", "DELETE"])
+ def test_user_is_not_local(self, method: str):
"""
Tests that a lookup for a user that is not a local returns a 400
"""
@@ -141,25 +100,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
)
channel = self.make_request(
- "GET",
- url,
- access_token=self.admin_user_tok,
- )
-
- self.assertEqual(400, channel.code, msg=channel.json_body)
- self.assertEqual("Can only lookup local users", channel.json_body["error"])
-
- channel = self.make_request(
- "PUT",
- url,
- access_token=self.admin_user_tok,
- )
-
- self.assertEqual(400, channel.code, msg=channel.json_body)
- self.assertEqual("Can only lookup local users", channel.json_body["error"])
-
- channel = self.make_request(
- "DELETE",
+ method,
url,
access_token=self.admin_user_tok,
)
@@ -219,12 +160,11 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
* (synapse.handlers.device.MAX_DEVICE_DISPLAY_NAME_LEN + 1)
}
- body = json.dumps(update)
channel = self.make_request(
"PUT",
self.url,
access_token=self.admin_user_tok,
- content=body.encode(encoding="utf_8"),
+ content=update,
)
self.assertEqual(400, channel.code, msg=channel.json_body)
@@ -275,12 +215,11 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
Tests a normal successful update of display name
"""
# Set new display_name
- body = json.dumps({"display_name": "new displayname"})
channel = self.make_request(
"PUT",
self.url,
access_token=self.admin_user_tok,
- content=body.encode(encoding="utf_8"),
+ content={"display_name": "new displayname"},
)
self.assertEqual(200, channel.code, msg=channel.json_body)
@@ -529,12 +468,11 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase):
"""
Tests that a remove of a device that does not exist returns 200.
"""
- body = json.dumps({"devices": ["unknown_device1", "unknown_device2"]})
channel = self.make_request(
"POST",
self.url,
access_token=self.admin_user_tok,
- content=body.encode(encoding="utf_8"),
+ content={"devices": ["unknown_device1", "unknown_device2"]},
)
# Delete unknown devices returns status 200
@@ -560,12 +498,11 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase):
device_ids.append(str(d["device_id"]))
# Delete devices
- body = json.dumps({"devices": device_ids})
channel = self.make_request(
"POST",
self.url,
access_token=self.admin_user_tok,
- content=body.encode(encoding="utf_8"),
+ content={"devices": device_ids},
)
self.assertEqual(200, channel.code, msg=channel.json_body)
diff --git a/tests/rest/admin/test_registration_tokens.py b/tests/rest/admin/test_registration_tokens.py
new file mode 100644
index 0000000000..4927321e5a
--- /dev/null
+++ b/tests/rest/admin/test_registration_tokens.py
@@ -0,0 +1,710 @@
+# Copyright 2021 Callum Brown
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import random
+import string
+
+import synapse.rest.admin
+from synapse.api.errors import Codes
+from synapse.rest.client import login
+
+from tests import unittest
+
+
+class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ self.other_user = self.register_user("user", "pass")
+ self.other_user_tok = self.login("user", "pass")
+
+ self.url = "/_synapse/admin/v1/registration_tokens"
+
+ def _new_token(self, **kwargs):
+ """Helper function to create a token."""
+ token = kwargs.get(
+ "token",
+ "".join(random.choices(string.ascii_letters, k=8)),
+ )
+ self.get_success(
+ self.store.db_pool.simple_insert(
+ "registration_tokens",
+ {
+ "token": token,
+ "uses_allowed": kwargs.get("uses_allowed", None),
+ "pending": kwargs.get("pending", 0),
+ "completed": kwargs.get("completed", 0),
+ "expiry_time": kwargs.get("expiry_time", None),
+ },
+ )
+ )
+ return token
+
+ # CREATION
+
+ def test_create_no_auth(self):
+ """Try to create a token without authentication."""
+ channel = self.make_request("POST", self.url + "/new", {})
+ self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+ def test_create_requester_not_admin(self):
+ """Try to create a token while not an admin."""
+ channel = self.make_request(
+ "POST",
+ self.url + "/new",
+ {},
+ access_token=self.other_user_tok,
+ )
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ def test_create_using_defaults(self):
+ """Create a token using all the defaults."""
+ channel = self.make_request(
+ "POST",
+ self.url + "/new",
+ {},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(len(channel.json_body["token"]), 16)
+ self.assertIsNone(channel.json_body["uses_allowed"])
+ self.assertIsNone(channel.json_body["expiry_time"])
+ self.assertEqual(channel.json_body["pending"], 0)
+ self.assertEqual(channel.json_body["completed"], 0)
+
+ def test_create_specifying_fields(self):
+ """Create a token specifying the value of all fields."""
+ data = {
+ "token": "abcd",
+ "uses_allowed": 1,
+ "expiry_time": self.clock.time_msec() + 1000000,
+ }
+
+ channel = self.make_request(
+ "POST",
+ self.url + "/new",
+ data,
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["token"], "abcd")
+ self.assertEqual(channel.json_body["uses_allowed"], 1)
+ self.assertEqual(channel.json_body["expiry_time"], data["expiry_time"])
+ self.assertEqual(channel.json_body["pending"], 0)
+ self.assertEqual(channel.json_body["completed"], 0)
+
+ def test_create_with_null_value(self):
+ """Create a token specifying unlimited uses and no expiry."""
+ data = {
+ "uses_allowed": None,
+ "expiry_time": None,
+ }
+
+ channel = self.make_request(
+ "POST",
+ self.url + "/new",
+ data,
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(len(channel.json_body["token"]), 16)
+ self.assertIsNone(channel.json_body["uses_allowed"])
+ self.assertIsNone(channel.json_body["expiry_time"])
+ self.assertEqual(channel.json_body["pending"], 0)
+ self.assertEqual(channel.json_body["completed"], 0)
+
+ def test_create_token_too_long(self):
+ """Check token longer than 64 chars is invalid."""
+ data = {"token": "a" * 65}
+
+ channel = self.make_request(
+ "POST",
+ self.url + "/new",
+ data,
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+ def test_create_token_invalid_chars(self):
+ """Check you can't create token with invalid characters."""
+ data = {
+ "token": "abc/def",
+ }
+
+ channel = self.make_request(
+ "POST",
+ self.url + "/new",
+ data,
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+ def test_create_token_already_exists(self):
+ """Check you can't create token that already exists."""
+ data = {
+ "token": "abcd",
+ }
+
+ channel1 = self.make_request(
+ "POST",
+ self.url + "/new",
+ data,
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, int(channel1.result["code"]), msg=channel1.result["body"])
+
+ channel2 = self.make_request(
+ "POST",
+ self.url + "/new",
+ data,
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(400, int(channel2.result["code"]), msg=channel2.result["body"])
+ self.assertEqual(channel2.json_body["errcode"], Codes.INVALID_PARAM)
+
+ def test_create_unable_to_generate_token(self):
+ """Check right error is raised when server can't generate unique token."""
+ # Create all possible single character tokens
+ tokens = []
+ for c in string.ascii_letters + string.digits + "-_":
+ tokens.append(
+ {
+ "token": c,
+ "uses_allowed": None,
+ "pending": 0,
+ "completed": 0,
+ "expiry_time": None,
+ }
+ )
+ self.get_success(
+ self.store.db_pool.simple_insert_many(
+ "registration_tokens",
+ tokens,
+ "create_all_registration_tokens",
+ )
+ )
+
+ # Check creating a single character token fails with a 500 status code
+ channel = self.make_request(
+ "POST",
+ self.url + "/new",
+ {"length": 1},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(500, int(channel.result["code"]), msg=channel.result["body"])
+
+ def test_create_uses_allowed(self):
+ """Check you can only create a token with good values for uses_allowed."""
+ # Should work with 0 (token is invalid from the start)
+ channel = self.make_request(
+ "POST",
+ self.url + "/new",
+ {"uses_allowed": 0},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["uses_allowed"], 0)
+
+ # Should fail with negative integer
+ channel = self.make_request(
+ "POST",
+ self.url + "/new",
+ {"uses_allowed": -5},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+ # Should fail with float
+ channel = self.make_request(
+ "POST",
+ self.url + "/new",
+ {"uses_allowed": 1.5},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+ def test_create_expiry_time(self):
+ """Check you can't create a token with an invalid expiry_time."""
+ # Should fail with a time in the past
+ channel = self.make_request(
+ "POST",
+ self.url + "/new",
+ {"expiry_time": self.clock.time_msec() - 10000},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+ # Should fail with float
+ channel = self.make_request(
+ "POST",
+ self.url + "/new",
+ {"expiry_time": self.clock.time_msec() + 1000000.5},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+ def test_create_length(self):
+ """Check you can only generate a token with a valid length."""
+ # Should work with 64
+ channel = self.make_request(
+ "POST",
+ self.url + "/new",
+ {"length": 64},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(len(channel.json_body["token"]), 64)
+
+ # Should fail with 0
+ channel = self.make_request(
+ "POST",
+ self.url + "/new",
+ {"length": 0},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+ # Should fail with a negative integer
+ channel = self.make_request(
+ "POST",
+ self.url + "/new",
+ {"length": -5},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+ # Should fail with a float
+ channel = self.make_request(
+ "POST",
+ self.url + "/new",
+ {"length": 8.5},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+ # Should fail with 65
+ channel = self.make_request(
+ "POST",
+ self.url + "/new",
+ {"length": 65},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+ # UPDATING
+
+ def test_update_no_auth(self):
+ """Try to update a token without authentication."""
+ channel = self.make_request(
+ "PUT",
+ self.url + "/1234", # Token doesn't exist but that doesn't matter
+ {},
+ )
+ self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+ def test_update_requester_not_admin(self):
+ """Try to update a token while not an admin."""
+ channel = self.make_request(
+ "PUT",
+ self.url + "/1234", # Token doesn't exist but that doesn't matter
+ {},
+ access_token=self.other_user_tok,
+ )
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ def test_update_non_existent(self):
+ """Try to update a token that doesn't exist."""
+ channel = self.make_request(
+ "PUT",
+ self.url + "/1234",
+ {"uses_allowed": 1},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
+
+ def test_update_uses_allowed(self):
+ """Test updating just uses_allowed."""
+ # Create new token using default values
+ token = self._new_token()
+
+ # Should succeed with 1
+ channel = self.make_request(
+ "PUT",
+ self.url + "/" + token,
+ {"uses_allowed": 1},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["uses_allowed"], 1)
+ self.assertIsNone(channel.json_body["expiry_time"])
+
+ # Should succeed with 0 (makes token invalid)
+ channel = self.make_request(
+ "PUT",
+ self.url + "/" + token,
+ {"uses_allowed": 0},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["uses_allowed"], 0)
+ self.assertIsNone(channel.json_body["expiry_time"])
+
+ # Should succeed with null
+ channel = self.make_request(
+ "PUT",
+ self.url + "/" + token,
+ {"uses_allowed": None},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertIsNone(channel.json_body["uses_allowed"])
+ self.assertIsNone(channel.json_body["expiry_time"])
+
+ # Should fail with a float
+ channel = self.make_request(
+ "PUT",
+ self.url + "/" + token,
+ {"uses_allowed": 1.5},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+ # Should fail with a negative integer
+ channel = self.make_request(
+ "PUT",
+ self.url + "/" + token,
+ {"uses_allowed": -5},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+ def test_update_expiry_time(self):
+ """Test updating just expiry_time."""
+ # Create new token using default values
+ token = self._new_token()
+ new_expiry_time = self.clock.time_msec() + 1000000
+
+ # Should succeed with a time in the future
+ channel = self.make_request(
+ "PUT",
+ self.url + "/" + token,
+ {"expiry_time": new_expiry_time},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["expiry_time"], new_expiry_time)
+ self.assertIsNone(channel.json_body["uses_allowed"])
+
+ # Should succeed with null
+ channel = self.make_request(
+ "PUT",
+ self.url + "/" + token,
+ {"expiry_time": None},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertIsNone(channel.json_body["expiry_time"])
+ self.assertIsNone(channel.json_body["uses_allowed"])
+
+ # Should fail with a time in the past
+ past_time = self.clock.time_msec() - 10000
+ channel = self.make_request(
+ "PUT",
+ self.url + "/" + token,
+ {"expiry_time": past_time},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+ # Should fail a float
+ channel = self.make_request(
+ "PUT",
+ self.url + "/" + token,
+ {"expiry_time": new_expiry_time + 0.5},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+ def test_update_both(self):
+ """Test updating both uses_allowed and expiry_time."""
+ # Create new token using default values
+ token = self._new_token()
+ new_expiry_time = self.clock.time_msec() + 1000000
+
+ data = {
+ "uses_allowed": 1,
+ "expiry_time": new_expiry_time,
+ }
+
+ channel = self.make_request(
+ "PUT",
+ self.url + "/" + token,
+ data,
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["uses_allowed"], 1)
+ self.assertEqual(channel.json_body["expiry_time"], new_expiry_time)
+
+ def test_update_invalid_type(self):
+ """Test using invalid types doesn't work."""
+ # Create new token using default values
+ token = self._new_token()
+
+ data = {
+ "uses_allowed": False,
+ "expiry_time": "1626430124000",
+ }
+
+ channel = self.make_request(
+ "PUT",
+ self.url + "/" + token,
+ data,
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+ # DELETING
+
+ def test_delete_no_auth(self):
+ """Try to delete a token without authentication."""
+ channel = self.make_request(
+ "DELETE",
+ self.url + "/1234", # Token doesn't exist but that doesn't matter
+ {},
+ )
+ self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+ def test_delete_requester_not_admin(self):
+ """Try to delete a token while not an admin."""
+ channel = self.make_request(
+ "DELETE",
+ self.url + "/1234", # Token doesn't exist but that doesn't matter
+ {},
+ access_token=self.other_user_tok,
+ )
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ def test_delete_non_existent(self):
+ """Try to delete a token that doesn't exist."""
+ channel = self.make_request(
+ "DELETE",
+ self.url + "/1234",
+ {},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
+
+ def test_delete(self):
+ """Test deleting a token."""
+ # Create new token using default values
+ token = self._new_token()
+
+ channel = self.make_request(
+ "DELETE",
+ self.url + "/" + token,
+ {},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ # GETTING ONE
+
+ def test_get_no_auth(self):
+ """Try to get a token without authentication."""
+ channel = self.make_request(
+ "GET",
+ self.url + "/1234", # Token doesn't exist but that doesn't matter
+ {},
+ )
+ self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+ def test_get_requester_not_admin(self):
+ """Try to get a token while not an admin."""
+ channel = self.make_request(
+ "GET",
+ self.url + "/1234", # Token doesn't exist but that doesn't matter
+ {},
+ access_token=self.other_user_tok,
+ )
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ def test_get_non_existent(self):
+ """Try to get a token that doesn't exist."""
+ channel = self.make_request(
+ "GET",
+ self.url + "/1234",
+ {},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
+
+ def test_get(self):
+ """Test getting a token."""
+ # Create new token using default values
+ token = self._new_token()
+
+ channel = self.make_request(
+ "GET",
+ self.url + "/" + token,
+ {},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["token"], token)
+ self.assertIsNone(channel.json_body["uses_allowed"])
+ self.assertIsNone(channel.json_body["expiry_time"])
+ self.assertEqual(channel.json_body["pending"], 0)
+ self.assertEqual(channel.json_body["completed"], 0)
+
+ # LISTING
+
+ def test_list_no_auth(self):
+ """Try to list tokens without authentication."""
+ channel = self.make_request("GET", self.url, {})
+ self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+ def test_list_requester_not_admin(self):
+ """Try to list tokens while not an admin."""
+ channel = self.make_request(
+ "GET",
+ self.url,
+ {},
+ access_token=self.other_user_tok,
+ )
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ def test_list_all(self):
+ """Test listing all tokens."""
+ # Create new token using default values
+ token = self._new_token()
+
+ channel = self.make_request(
+ "GET",
+ self.url,
+ {},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(len(channel.json_body["registration_tokens"]), 1)
+ token_info = channel.json_body["registration_tokens"][0]
+ self.assertEqual(token_info["token"], token)
+ self.assertIsNone(token_info["uses_allowed"])
+ self.assertIsNone(token_info["expiry_time"])
+ self.assertEqual(token_info["pending"], 0)
+ self.assertEqual(token_info["completed"], 0)
+
+ def test_list_invalid_query_parameter(self):
+ """Test with `valid` query parameter not `true` or `false`."""
+ channel = self.make_request(
+ "GET",
+ self.url + "?valid=x",
+ {},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+
+ def _test_list_query_parameter(self, valid: str):
+ """Helper used to test both valid=true and valid=false."""
+ # Create 2 valid and 2 invalid tokens.
+ now = self.hs.get_clock().time_msec()
+ # Create always valid token
+ valid1 = self._new_token()
+ # Create token that hasn't been used up
+ valid2 = self._new_token(uses_allowed=1)
+ # Create token that has expired
+ invalid1 = self._new_token(expiry_time=now - 10000)
+ # Create token that has been used up but hasn't expired
+ invalid2 = self._new_token(
+ uses_allowed=2,
+ pending=1,
+ completed=1,
+ expiry_time=now + 1000000,
+ )
+
+ if valid == "true":
+ tokens = [valid1, valid2]
+ else:
+ tokens = [invalid1, invalid2]
+
+ channel = self.make_request(
+ "GET",
+ self.url + "?valid=" + valid,
+ {},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(len(channel.json_body["registration_tokens"]), 2)
+ token_info_1 = channel.json_body["registration_tokens"][0]
+ token_info_2 = channel.json_body["registration_tokens"][1]
+ self.assertIn(token_info_1["token"], tokens)
+ self.assertIn(token_info_2["token"], tokens)
+
+ def test_list_valid(self):
+ """Test listing just valid tokens."""
+ self._test_list_query_parameter(valid="true")
+
+ def test_list_invalid(self):
+ """Test listing just invalid tokens."""
+ self._test_list_query_parameter(valid="false")
diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py
index c9d4731017..40e032df7f 100644
--- a/tests/rest/admin/test_room.py
+++ b/tests/rest/admin/test_room.py
@@ -29,123 +29,6 @@ from tests import unittest
"""Tests admin REST events for /rooms paths."""
-class ShutdownRoomTestCase(unittest.HomeserverTestCase):
- servlets = [
- synapse.rest.admin.register_servlets_for_client_rest_resource,
- login.register_servlets,
- events.register_servlets,
- room.register_servlets,
- room.register_deprecated_servlets,
- ]
-
- def prepare(self, reactor, clock, hs):
- self.event_creation_handler = hs.get_event_creation_handler()
- hs.config.user_consent_version = "1"
-
- consent_uri_builder = Mock()
- consent_uri_builder.build_user_consent_uri.return_value = "http://example.com"
- self.event_creation_handler._consent_uri_builder = consent_uri_builder
-
- self.store = hs.get_datastore()
-
- self.admin_user = self.register_user("admin", "pass", admin=True)
- self.admin_user_tok = self.login("admin", "pass")
-
- self.other_user = self.register_user("user", "pass")
- self.other_user_token = self.login("user", "pass")
-
- # Mark the admin user as having consented
- self.get_success(self.store.user_set_consent_version(self.admin_user, "1"))
-
- def test_shutdown_room_consent(self):
- """Test that we can shutdown rooms with local users who have not
- yet accepted the privacy policy. This used to fail when we tried to
- force part the user from the old room.
- """
- self.event_creation_handler._block_events_without_consent_error = None
-
- room_id = self.helper.create_room_as(self.other_user, tok=self.other_user_token)
-
- # Assert one user in room
- users_in_room = self.get_success(self.store.get_users_in_room(room_id))
- self.assertEqual([self.other_user], users_in_room)
-
- # Enable require consent to send events
- self.event_creation_handler._block_events_without_consent_error = "Error"
-
- # Assert that the user is getting consent error
- self.helper.send(
- room_id, body="foo", tok=self.other_user_token, expect_code=403
- )
-
- # Test that the admin can still send shutdown
- url = "/_synapse/admin/v1/shutdown_room/" + room_id
- channel = self.make_request(
- "POST",
- url.encode("ascii"),
- json.dumps({"new_room_user_id": self.admin_user}),
- access_token=self.admin_user_tok,
- )
-
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
-
- # Assert there is now no longer anyone in the room
- users_in_room = self.get_success(self.store.get_users_in_room(room_id))
- self.assertEqual([], users_in_room)
-
- def test_shutdown_room_block_peek(self):
- """Test that a world_readable room can no longer be peeked into after
- it has been shut down.
- """
-
- self.event_creation_handler._block_events_without_consent_error = None
-
- room_id = self.helper.create_room_as(self.other_user, tok=self.other_user_token)
-
- # Enable world readable
- url = "rooms/%s/state/m.room.history_visibility" % (room_id,)
- channel = self.make_request(
- "PUT",
- url.encode("ascii"),
- json.dumps({"history_visibility": "world_readable"}),
- access_token=self.other_user_token,
- )
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
-
- # Test that the admin can still send shutdown
- url = "/_synapse/admin/v1/shutdown_room/" + room_id
- channel = self.make_request(
- "POST",
- url.encode("ascii"),
- json.dumps({"new_room_user_id": self.admin_user}),
- access_token=self.admin_user_tok,
- )
-
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
-
- # Assert we can no longer peek into the room
- self._assert_peek(room_id, expect_code=403)
-
- def _assert_peek(self, room_id, expect_code):
- """Assert that the admin user can (or cannot) peek into the room."""
-
- url = "rooms/%s/initialSync" % (room_id,)
- channel = self.make_request(
- "GET", url.encode("ascii"), access_token=self.admin_user_tok
- )
- self.assertEqual(
- expect_code, int(channel.result["code"]), msg=channel.result["body"]
- )
-
- url = "events?timeout=0&room_id=" + room_id
- channel = self.make_request(
- "GET", url.encode("ascii"), access_token=self.admin_user_tok
- )
- self.assertEqual(
- expect_code, int(channel.result["code"]), msg=channel.result["body"]
- )
-
-
@parameterized_class(
("method", "url_template"),
[
@@ -557,51 +440,6 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
)
-class PurgeRoomTestCase(unittest.HomeserverTestCase):
- """Test /purge_room admin API."""
-
- servlets = [
- synapse.rest.admin.register_servlets,
- login.register_servlets,
- room.register_servlets,
- ]
-
- def prepare(self, reactor, clock, hs):
- self.store = hs.get_datastore()
-
- self.admin_user = self.register_user("admin", "pass", admin=True)
- self.admin_user_tok = self.login("admin", "pass")
-
- def test_purge_room(self):
- room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
-
- # All users have to have left the room.
- self.helper.leave(room_id, user=self.admin_user, tok=self.admin_user_tok)
-
- url = "/_synapse/admin/v1/purge_room"
- channel = self.make_request(
- "POST",
- url.encode("ascii"),
- {"room_id": room_id},
- access_token=self.admin_user_tok,
- )
-
- self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
-
- # Test that the following tables have been purged of all rows related to the room.
- for table in PURGE_TABLES:
- count = self.get_success(
- self.store.db_pool.simple_select_one_onecol(
- table=table,
- keyvalues={"room_id": room_id},
- retcol="COUNT(*)",
- desc="test_purge_room",
- )
- )
-
- self.assertEqual(count, 0, msg=f"Rows not purged in {table}")
-
-
class RoomTestCase(unittest.HomeserverTestCase):
"""Test /room admin API."""
diff --git a/tests/rest/admin/test_server_notice.py b/tests/rest/admin/test_server_notice.py
new file mode 100644
index 0000000000..fbceba3254
--- /dev/null
+++ b/tests/rest/admin/test_server_notice.py
@@ -0,0 +1,450 @@
+# Copyright 2021 Dirk Klimpel
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List
+
+import synapse.rest.admin
+from synapse.api.errors import Codes
+from synapse.rest.client import login, room, sync
+from synapse.storage.roommember import RoomsForUser
+from synapse.types import JsonDict
+
+from tests import unittest
+from tests.unittest import override_config
+
+
+class ServerNoticeTestCase(unittest.HomeserverTestCase):
+
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ sync.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
+ self.room_shutdown_handler = hs.get_room_shutdown_handler()
+ self.pagination_handler = hs.get_pagination_handler()
+ self.server_notices_manager = self.hs.get_server_notices_manager()
+
+ # Create user
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ self.other_user = self.register_user("user", "pass")
+ self.other_user_token = self.login("user", "pass")
+
+ self.url = "/_synapse/admin/v1/send_server_notice"
+
+ def test_no_auth(self):
+ """Try to send a server notice without authentication."""
+ channel = self.make_request("POST", self.url)
+
+ self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+ def test_requester_is_no_admin(self):
+ """If the user is not a server admin, an error is returned."""
+ channel = self.make_request(
+ "POST",
+ self.url,
+ access_token=self.other_user_token,
+ )
+
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ @override_config({"server_notices": {"system_mxid_localpart": "notices"}})
+ def test_user_does_not_exist(self):
+ """Tests that a lookup for a user that does not exist returns a 404"""
+ channel = self.make_request(
+ "POST",
+ self.url,
+ access_token=self.admin_user_tok,
+ content={"user_id": "@unknown_person:test", "content": ""},
+ )
+
+ self.assertEqual(404, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
+
+ @override_config({"server_notices": {"system_mxid_localpart": "notices"}})
+ def test_user_is_not_local(self):
+ """
+ Tests that a lookup for a user that is not a local returns a 400
+ """
+ channel = self.make_request(
+ "POST",
+ self.url,
+ access_token=self.admin_user_tok,
+ content={
+ "user_id": "@unknown_person:unknown_domain",
+ "content": "",
+ },
+ )
+
+ self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual(
+ "Server notices can only be sent to local users", channel.json_body["error"]
+ )
+
+ @override_config({"server_notices": {"system_mxid_localpart": "notices"}})
+ def test_invalid_parameter(self):
+ """If parameters are invalid, an error is returned."""
+
+ # no content, no user
+ channel = self.make_request(
+ "POST",
+ self.url,
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.NOT_JSON, channel.json_body["errcode"])
+
+ # no content
+ channel = self.make_request(
+ "POST",
+ self.url,
+ access_token=self.admin_user_tok,
+ content={"user_id": self.other_user},
+ )
+
+ self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"])
+
+ # no body
+ channel = self.make_request(
+ "POST",
+ self.url,
+ access_token=self.admin_user_tok,
+ content={"user_id": self.other_user, "content": ""},
+ )
+
+ self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
+ self.assertEqual("'body' not in content", channel.json_body["error"])
+
+ # no msgtype
+ channel = self.make_request(
+ "POST",
+ self.url,
+ access_token=self.admin_user_tok,
+ content={"user_id": self.other_user, "content": {"body": ""}},
+ )
+
+ self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
+ self.assertEqual("'msgtype' not in content", channel.json_body["error"])
+
+ def test_server_notice_disabled(self):
+ """Tests that server returns error if server notice is disabled"""
+ channel = self.make_request(
+ "POST",
+ self.url,
+ access_token=self.admin_user_tok,
+ content={
+ "user_id": self.other_user,
+ "content": "",
+ },
+ )
+
+ self.assertEqual(400, channel.code, msg=channel.json_body)
+ self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
+ self.assertEqual(
+ "Server notices are not enabled on this server", channel.json_body["error"]
+ )
+
+ @override_config({"server_notices": {"system_mxid_localpart": "notices"}})
+ def test_send_server_notice(self):
+ """
+ Tests that sending two server notices is successfully,
+ the server uses the same room and do not send messages twice.
+ """
+ # user has no room memberships
+ self._check_invite_and_join_status(self.other_user, 0, 0)
+
+ # send first message
+ channel = self.make_request(
+ "POST",
+ self.url,
+ access_token=self.admin_user_tok,
+ content={
+ "user_id": self.other_user,
+ "content": {"msgtype": "m.text", "body": "test msg one"},
+ },
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+
+ # user has one invite
+ invited_rooms = self._check_invite_and_join_status(self.other_user, 1, 0)
+ room_id = invited_rooms[0].room_id
+
+ # user joins the room and is member now
+ self.helper.join(room=room_id, user=self.other_user, tok=self.other_user_token)
+ self._check_invite_and_join_status(self.other_user, 0, 1)
+
+ # get messages
+ messages = self._sync_and_get_messages(room_id, self.other_user_token)
+ self.assertEqual(len(messages), 1)
+ self.assertEqual(messages[0]["content"]["body"], "test msg one")
+ self.assertEqual(messages[0]["sender"], "@notices:test")
+
+ # invalidate cache of server notices room_ids
+ self.get_success(
+ self.server_notices_manager.get_or_create_notice_room_for_user.invalidate_all()
+ )
+
+ # send second message
+ channel = self.make_request(
+ "POST",
+ self.url,
+ access_token=self.admin_user_tok,
+ content={
+ "user_id": self.other_user,
+ "content": {"msgtype": "m.text", "body": "test msg two"},
+ },
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+
+ # user has no new invites or memberships
+ self._check_invite_and_join_status(self.other_user, 0, 1)
+
+ # get messages
+ messages = self._sync_and_get_messages(room_id, self.other_user_token)
+
+ self.assertEqual(len(messages), 2)
+ self.assertEqual(messages[0]["content"]["body"], "test msg one")
+ self.assertEqual(messages[0]["sender"], "@notices:test")
+ self.assertEqual(messages[1]["content"]["body"], "test msg two")
+ self.assertEqual(messages[1]["sender"], "@notices:test")
+
+ @override_config({"server_notices": {"system_mxid_localpart": "notices"}})
+ def test_send_server_notice_leave_room(self):
+ """
+ Tests that sending a server notices is successfully.
+ The user leaves the room and the second message appears
+ in a new room.
+ """
+ # user has no room memberships
+ self._check_invite_and_join_status(self.other_user, 0, 0)
+
+ # send first message
+ channel = self.make_request(
+ "POST",
+ self.url,
+ access_token=self.admin_user_tok,
+ content={
+ "user_id": self.other_user,
+ "content": {"msgtype": "m.text", "body": "test msg one"},
+ },
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+
+ # user has one invite
+ invited_rooms = self._check_invite_and_join_status(self.other_user, 1, 0)
+ first_room_id = invited_rooms[0].room_id
+
+ # user joins the room and is member now
+ self.helper.join(
+ room=first_room_id, user=self.other_user, tok=self.other_user_token
+ )
+ self._check_invite_and_join_status(self.other_user, 0, 1)
+
+ # get messages
+ messages = self._sync_and_get_messages(first_room_id, self.other_user_token)
+ self.assertEqual(len(messages), 1)
+ self.assertEqual(messages[0]["content"]["body"], "test msg one")
+ self.assertEqual(messages[0]["sender"], "@notices:test")
+
+ # user leaves the romm
+ self.helper.leave(
+ room=first_room_id, user=self.other_user, tok=self.other_user_token
+ )
+
+ # user is not member anymore
+ self._check_invite_and_join_status(self.other_user, 0, 0)
+
+ # invalidate cache of server notices room_ids
+ # if server tries to send to a cached room_id the user gets the message
+ # in old room
+ self.get_success(
+ self.server_notices_manager.get_or_create_notice_room_for_user.invalidate_all()
+ )
+
+ # send second message
+ channel = self.make_request(
+ "POST",
+ self.url,
+ access_token=self.admin_user_tok,
+ content={
+ "user_id": self.other_user,
+ "content": {"msgtype": "m.text", "body": "test msg two"},
+ },
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+
+ # user has one invite
+ invited_rooms = self._check_invite_and_join_status(self.other_user, 1, 0)
+ second_room_id = invited_rooms[0].room_id
+
+ # user joins the room and is member now
+ self.helper.join(
+ room=second_room_id, user=self.other_user, tok=self.other_user_token
+ )
+ self._check_invite_and_join_status(self.other_user, 0, 1)
+
+ # get messages
+ messages = self._sync_and_get_messages(second_room_id, self.other_user_token)
+
+ self.assertEqual(len(messages), 1)
+ self.assertEqual(messages[0]["content"]["body"], "test msg two")
+ self.assertEqual(messages[0]["sender"], "@notices:test")
+ # room has the same id
+ self.assertNotEqual(first_room_id, second_room_id)
+
+ @override_config({"server_notices": {"system_mxid_localpart": "notices"}})
+ def test_send_server_notice_delete_room(self):
+ """
+ Tests that the user get server notice in a new room
+ after the first server notice room was deleted.
+ """
+ # user has no room memberships
+ self._check_invite_and_join_status(self.other_user, 0, 0)
+
+ # send first message
+ channel = self.make_request(
+ "POST",
+ self.url,
+ access_token=self.admin_user_tok,
+ content={
+ "user_id": self.other_user,
+ "content": {"msgtype": "m.text", "body": "test msg one"},
+ },
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+
+ # user has one invite
+ invited_rooms = self._check_invite_and_join_status(self.other_user, 1, 0)
+ first_room_id = invited_rooms[0].room_id
+
+ # user joins the room and is member now
+ self.helper.join(
+ room=first_room_id, user=self.other_user, tok=self.other_user_token
+ )
+ self._check_invite_and_join_status(self.other_user, 0, 1)
+
+ # get messages
+ messages = self._sync_and_get_messages(first_room_id, self.other_user_token)
+ self.assertEqual(len(messages), 1)
+ self.assertEqual(messages[0]["content"]["body"], "test msg one")
+ self.assertEqual(messages[0]["sender"], "@notices:test")
+
+ # shut down and purge room
+ self.get_success(
+ self.room_shutdown_handler.shutdown_room(first_room_id, self.admin_user)
+ )
+ self.get_success(self.pagination_handler.purge_room(first_room_id))
+
+ # user is not member anymore
+ self._check_invite_and_join_status(self.other_user, 0, 0)
+
+ # It doesn't really matter what API we use here, we just want to assert
+ # that the room doesn't exist.
+ summary = self.get_success(self.store.get_room_summary(first_room_id))
+ # The summary should be empty since the room doesn't exist.
+ self.assertEqual(summary, {})
+
+ # invalidate cache of server notices room_ids
+ # if server tries to send to a cached room_id it gives an error
+ self.get_success(
+ self.server_notices_manager.get_or_create_notice_room_for_user.invalidate_all()
+ )
+
+ # send second message
+ channel = self.make_request(
+ "POST",
+ self.url,
+ access_token=self.admin_user_tok,
+ content={
+ "user_id": self.other_user,
+ "content": {"msgtype": "m.text", "body": "test msg two"},
+ },
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+
+ # user has one invite
+ invited_rooms = self._check_invite_and_join_status(self.other_user, 1, 0)
+ second_room_id = invited_rooms[0].room_id
+
+ # user joins the room and is member now
+ self.helper.join(
+ room=second_room_id, user=self.other_user, tok=self.other_user_token
+ )
+ self._check_invite_and_join_status(self.other_user, 0, 1)
+
+ # get message
+ messages = self._sync_and_get_messages(second_room_id, self.other_user_token)
+
+ self.assertEqual(len(messages), 1)
+ self.assertEqual(messages[0]["content"]["body"], "test msg two")
+ self.assertEqual(messages[0]["sender"], "@notices:test")
+ # second room has new ID
+ self.assertNotEqual(first_room_id, second_room_id)
+
+ def _check_invite_and_join_status(
+ self, user_id: str, expected_invites: int, expected_memberships: int
+ ) -> RoomsForUser:
+ """Check invite and room membership status of a user.
+
+ Args
+ user_id: user to check
+ expected_invites: number of expected invites of this user
+ expected_memberships: number of expected room memberships of this user
+ Returns
+ room_ids from the rooms that the user is invited
+ """
+
+ invited_rooms = self.get_success(
+ self.store.get_invited_rooms_for_local_user(user_id)
+ )
+ self.assertEqual(expected_invites, len(invited_rooms))
+
+ room_ids = self.get_success(self.store.get_rooms_for_user(user_id))
+ self.assertEqual(expected_memberships, len(room_ids))
+
+ return invited_rooms
+
+ def _sync_and_get_messages(self, room_id: str, token: str) -> List[JsonDict]:
+ """
+ Do a sync and get messages of a room.
+
+ Args
+ room_id: room that contains the messages
+ token: access token of user
+
+ Returns
+ list of messages contained in the room
+ """
+ channel = self.make_request(
+ "GET", "/_matrix/client/r0/sync", access_token=token
+ )
+ self.assertEqual(channel.code, 200)
+
+ # Get the messages
+ room = channel.json_body["rooms"]["join"][room_id]
+ messages = [
+ x for x in room["timeline"]["events"] if x["type"] == "m.room.message"
+ ]
+ return messages
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index ef77275238..ee204c404b 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -1431,12 +1431,14 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual("Bob's name", channel.json_body["displayname"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
+ self.assertEqual(1, len(channel.json_body["threepids"]))
self.assertEqual(
"external_id1", channel.json_body["external_ids"][0]["external_id"]
)
self.assertEqual(
"auth_provider1", channel.json_body["external_ids"][0]["auth_provider"]
)
+ self.assertEqual(1, len(channel.json_body["external_ids"]))
self.assertFalse(channel.json_body["admin"])
self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"])
self._check_fields(channel.json_body)
@@ -1676,18 +1678,53 @@ class UserRestTestCase(unittest.HomeserverTestCase):
Test setting threepid for an other user.
"""
- # Delete old and add new threepid to user
+ # Add two threepids to user
channel = self.make_request(
"PUT",
self.url_other_user,
access_token=self.admin_user_tok,
- content={"threepids": [{"medium": "email", "address": "bob3@bob.bob"}]},
+ content={
+ "threepids": [
+ {"medium": "email", "address": "bob1@bob.bob"},
+ {"medium": "email", "address": "bob2@bob.bob"},
+ ],
+ },
)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
+ self.assertEqual(2, len(channel.json_body["threepids"]))
+ # result does not always have the same sort order, therefore it becomes sorted
+ sorted_result = sorted(
+ channel.json_body["threepids"], key=lambda k: k["address"]
+ )
+ self.assertEqual("email", sorted_result[0]["medium"])
+ self.assertEqual("bob1@bob.bob", sorted_result[0]["address"])
+ self.assertEqual("email", sorted_result[1]["medium"])
+ self.assertEqual("bob2@bob.bob", sorted_result[1]["address"])
+ self._check_fields(channel.json_body)
+
+ # Set a new and remove a threepid
+ channel = self.make_request(
+ "PUT",
+ self.url_other_user,
+ access_token=self.admin_user_tok,
+ content={
+ "threepids": [
+ {"medium": "email", "address": "bob2@bob.bob"},
+ {"medium": "email", "address": "bob3@bob.bob"},
+ ],
+ },
+ )
+
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual("@user:test", channel.json_body["name"])
+ self.assertEqual(2, len(channel.json_body["threepids"]))
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
- self.assertEqual("bob3@bob.bob", channel.json_body["threepids"][0]["address"])
+ self.assertEqual("bob2@bob.bob", channel.json_body["threepids"][0]["address"])
+ self.assertEqual("email", channel.json_body["threepids"][1]["medium"])
+ self.assertEqual("bob3@bob.bob", channel.json_body["threepids"][1]["address"])
+ self._check_fields(channel.json_body)
# Get user
channel = self.make_request(
@@ -1698,8 +1735,24 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
+ self.assertEqual(2, len(channel.json_body["threepids"]))
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
- self.assertEqual("bob3@bob.bob", channel.json_body["threepids"][0]["address"])
+ self.assertEqual("bob2@bob.bob", channel.json_body["threepids"][0]["address"])
+ self.assertEqual("email", channel.json_body["threepids"][1]["medium"])
+ self.assertEqual("bob3@bob.bob", channel.json_body["threepids"][1]["address"])
+ self._check_fields(channel.json_body)
+
+ # Remove threepids
+ channel = self.make_request(
+ "PUT",
+ self.url_other_user,
+ access_token=self.admin_user_tok,
+ content={"threepids": []},
+ )
+ self.assertEqual(200, channel.code, msg=channel.json_body)
+ self.assertEqual("@user:test", channel.json_body["name"])
+ self.assertEqual(0, len(channel.json_body["threepids"]))
+ self._check_fields(channel.json_body)
def test_set_external_id(self):
"""
@@ -1778,6 +1831,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
+ self.assertEqual(2, len(channel.json_body["external_ids"]))
self.assertEqual(
channel.json_body["external_ids"],
[
diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/test_account.py
index b946fca8b3..b946fca8b3 100644
--- a/tests/rest/client/v2_alpha/test_account.py
+++ b/tests/rest/client/test_account.py
diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/test_auth.py
index cf5cfb910c..e2fcbdc63a 100644
--- a/tests/rest/client/v2_alpha/test_auth.py
+++ b/tests/rest/client/test_auth.py
@@ -25,7 +25,7 @@ from synapse.types import JsonDict, UserID
from tests import unittest
from tests.handlers.test_oidc import HAS_OIDC
-from tests.rest.client.v1.utils import TEST_OIDC_CONFIG
+from tests.rest.client.utils import TEST_OIDC_CONFIG
from tests.server import FakeChannel
from tests.unittest import override_config, skip_unless
diff --git a/tests/rest/client/v2_alpha/test_capabilities.py b/tests/rest/client/test_capabilities.py
index 13b3c5f499..422361b62a 100644
--- a/tests/rest/client/v2_alpha/test_capabilities.py
+++ b/tests/rest/client/test_capabilities.py
@@ -30,19 +30,22 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
self.url = b"/_matrix/client/r0/capabilities"
hs = self.setup_test_homeserver()
- self.store = hs.get_datastore()
self.config = hs.config
self.auth_handler = hs.get_auth_handler()
return hs
+ def prepare(self, reactor, clock, hs):
+ self.localpart = "user"
+ self.password = "pass"
+ self.user = self.register_user(self.localpart, self.password)
+
def test_check_auth_required(self):
channel = self.make_request("GET", self.url)
self.assertEqual(channel.code, 401)
def test_get_room_version_capabilities(self):
- self.register_user("user", "pass")
- access_token = self.login("user", "pass")
+ access_token = self.login(self.localpart, self.password)
channel = self.make_request("GET", self.url, access_token=access_token)
capabilities = channel.json_body["capabilities"]
@@ -57,10 +60,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
)
def test_get_change_password_capabilities_password_login(self):
- localpart = "user"
- password = "pass"
- user = self.register_user(localpart, password)
- access_token = self.login(user, password)
+ access_token = self.login(self.localpart, self.password)
channel = self.make_request("GET", self.url, access_token=access_token)
capabilities = channel.json_body["capabilities"]
@@ -70,12 +70,9 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
@override_config({"password_config": {"localdb_enabled": False}})
def test_get_change_password_capabilities_localdb_disabled(self):
- localpart = "user"
- password = "pass"
- user = self.register_user(localpart, password)
access_token = self.get_success(
self.auth_handler.get_access_token_for_user_id(
- user, device_id=None, valid_until_ms=None
+ self.user, device_id=None, valid_until_ms=None
)
)
@@ -87,12 +84,9 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
@override_config({"password_config": {"enabled": False}})
def test_get_change_password_capabilities_password_disabled(self):
- localpart = "user"
- password = "pass"
- user = self.register_user(localpart, password)
access_token = self.get_success(
self.auth_handler.get_access_token_for_user_id(
- user, device_id=None, valid_until_ms=None
+ self.user, device_id=None, valid_until_ms=None
)
)
@@ -102,14 +96,86 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200)
self.assertFalse(capabilities["m.change_password"]["enabled"])
+ def test_get_change_users_attributes_capabilities_when_msc3283_disabled(self):
+ """Test that per default msc3283 is disabled server returns `m.change_password`."""
+ access_token = self.login(self.localpart, self.password)
+
+ channel = self.make_request("GET", self.url, access_token=access_token)
+ capabilities = channel.json_body["capabilities"]
+
+ self.assertEqual(channel.code, 200)
+ self.assertTrue(capabilities["m.change_password"]["enabled"])
+ self.assertNotIn("org.matrix.msc3283.set_displayname", capabilities)
+ self.assertNotIn("org.matrix.msc3283.set_avatar_url", capabilities)
+ self.assertNotIn("org.matrix.msc3283.3pid_changes", capabilities)
+
+ @override_config({"experimental_features": {"msc3283_enabled": True}})
+ def test_get_change_users_attributes_capabilities_when_msc3283_enabled(self):
+ """Test if msc3283 is enabled server returns capabilities."""
+ access_token = self.login(self.localpart, self.password)
+
+ channel = self.make_request("GET", self.url, access_token=access_token)
+ capabilities = channel.json_body["capabilities"]
+
+ self.assertEqual(channel.code, 200)
+ self.assertTrue(capabilities["m.change_password"]["enabled"])
+ self.assertTrue(capabilities["org.matrix.msc3283.set_displayname"]["enabled"])
+ self.assertTrue(capabilities["org.matrix.msc3283.set_avatar_url"]["enabled"])
+ self.assertTrue(capabilities["org.matrix.msc3283.3pid_changes"]["enabled"])
+
+ @override_config(
+ {
+ "enable_set_displayname": False,
+ "experimental_features": {"msc3283_enabled": True},
+ }
+ )
+ def test_get_set_displayname_capabilities_displayname_disabled(self):
+ """Test if set displayname is disabled that the server responds it."""
+ access_token = self.login(self.localpart, self.password)
+
+ channel = self.make_request("GET", self.url, access_token=access_token)
+ capabilities = channel.json_body["capabilities"]
+
+ self.assertEqual(channel.code, 200)
+ self.assertFalse(capabilities["org.matrix.msc3283.set_displayname"]["enabled"])
+
+ @override_config(
+ {
+ "enable_set_avatar_url": False,
+ "experimental_features": {"msc3283_enabled": True},
+ }
+ )
+ def test_get_set_avatar_url_capabilities_avatar_url_disabled(self):
+ """Test if set avatar_url is disabled that the server responds it."""
+ access_token = self.login(self.localpart, self.password)
+
+ channel = self.make_request("GET", self.url, access_token=access_token)
+ capabilities = channel.json_body["capabilities"]
+
+ self.assertEqual(channel.code, 200)
+ self.assertFalse(capabilities["org.matrix.msc3283.set_avatar_url"]["enabled"])
+
+ @override_config(
+ {
+ "enable_3pid_changes": False,
+ "experimental_features": {"msc3283_enabled": True},
+ }
+ )
+ def test_change_3pid_capabilities_3pid_disabled(self):
+ """Test if change 3pid is disabled that the server responds it."""
+ access_token = self.login(self.localpart, self.password)
+
+ channel = self.make_request("GET", self.url, access_token=access_token)
+ capabilities = channel.json_body["capabilities"]
+
+ self.assertEqual(channel.code, 200)
+ self.assertFalse(capabilities["org.matrix.msc3283.3pid_changes"]["enabled"])
+
@override_config({"experimental_features": {"msc3244_enabled": False}})
def test_get_does_not_include_msc3244_fields_when_disabled(self):
- localpart = "user"
- password = "pass"
- user = self.register_user(localpart, password)
access_token = self.get_success(
self.auth_handler.get_access_token_for_user_id(
- user, device_id=None, valid_until_ms=None
+ self.user, device_id=None, valid_until_ms=None
)
)
@@ -122,12 +188,9 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
)
def test_get_does_include_msc3244_fields_when_enabled(self):
- localpart = "user"
- password = "pass"
- user = self.register_user(localpart, password)
access_token = self.get_success(
self.auth_handler.get_access_token_for_user_id(
- user, device_id=None, valid_until_ms=None
+ self.user, device_id=None, valid_until_ms=None
)
)
diff --git a/tests/rest/client/v1/test_directory.py b/tests/rest/client/test_directory.py
index d2181ea907..d2181ea907 100644
--- a/tests/rest/client/v1/test_directory.py
+++ b/tests/rest/client/test_directory.py
diff --git a/tests/rest/client/v1/test_events.py b/tests/rest/client/test_events.py
index a90294003e..a90294003e 100644
--- a/tests/rest/client/v1/test_events.py
+++ b/tests/rest/client/test_events.py
diff --git a/tests/rest/client/v2_alpha/test_filter.py b/tests/rest/client/test_filter.py
index 475c6bed3d..475c6bed3d 100644
--- a/tests/rest/client/v2_alpha/test_filter.py
+++ b/tests/rest/client/test_filter.py
diff --git a/tests/rest/client/v2_alpha/test_groups.py b/tests/rest/client/test_groups.py
index ad0425ae65..ad0425ae65 100644
--- a/tests/rest/client/v2_alpha/test_groups.py
+++ b/tests/rest/client/test_groups.py
diff --git a/tests/rest/client/test_keys.py b/tests/rest/client/test_keys.py
new file mode 100644
index 0000000000..d7fa635eae
--- /dev/null
+++ b/tests/rest/client/test_keys.py
@@ -0,0 +1,91 @@
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License
+
+from http import HTTPStatus
+
+from synapse.api.errors import Codes
+from synapse.rest import admin
+from synapse.rest.client import keys, login
+
+from tests import unittest
+
+
+class KeyQueryTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ keys.register_servlets,
+ admin.register_servlets_for_client_rest_resource,
+ login.register_servlets,
+ ]
+
+ def test_rejects_device_id_ice_key_outside_of_list(self):
+ self.register_user("alice", "wonderland")
+ alice_token = self.login("alice", "wonderland")
+ bob = self.register_user("bob", "uncle")
+ channel = self.make_request(
+ "POST",
+ "/_matrix/client/r0/keys/query",
+ {
+ "device_keys": {
+ bob: "device_id1",
+ },
+ },
+ alice_token,
+ )
+ self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
+ self.assertEqual(
+ channel.json_body["errcode"],
+ Codes.BAD_JSON,
+ channel.result,
+ )
+
+ def test_rejects_device_key_given_as_map_to_bool(self):
+ self.register_user("alice", "wonderland")
+ alice_token = self.login("alice", "wonderland")
+ bob = self.register_user("bob", "uncle")
+ channel = self.make_request(
+ "POST",
+ "/_matrix/client/r0/keys/query",
+ {
+ "device_keys": {
+ bob: {
+ "device_id1": True,
+ },
+ },
+ },
+ alice_token,
+ )
+
+ self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
+ self.assertEqual(
+ channel.json_body["errcode"],
+ Codes.BAD_JSON,
+ channel.result,
+ )
+
+ def test_requires_device_key(self):
+ """`device_keys` is required. We should complain if it's missing."""
+ self.register_user("alice", "wonderland")
+ alice_token = self.login("alice", "wonderland")
+ channel = self.make_request(
+ "POST",
+ "/_matrix/client/r0/keys/query",
+ {},
+ alice_token,
+ )
+ self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
+ self.assertEqual(
+ channel.json_body["errcode"],
+ Codes.BAD_JSON,
+ channel.result,
+ )
diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/test_login.py
index eba3552b19..5b2243fe52 100644
--- a/tests/rest/client/v1/test_login.py
+++ b/tests/rest/client/test_login.py
@@ -32,7 +32,7 @@ from synapse.types import create_requester
from tests import unittest
from tests.handlers.test_oidc import HAS_OIDC
from tests.handlers.test_saml import has_saml2
-from tests.rest.client.v1.utils import TEST_OIDC_AUTH_ENDPOINT, TEST_OIDC_CONFIG
+from tests.rest.client.utils import TEST_OIDC_AUTH_ENDPOINT, TEST_OIDC_CONFIG
from tests.test_utils.html_parsers import TestHtmlParser
from tests.unittest import HomeserverTestCase, override_config, skip_unless
diff --git a/tests/rest/client/v2_alpha/test_password_policy.py b/tests/rest/client/test_password_policy.py
index 3cf5871899..3cf5871899 100644
--- a/tests/rest/client/v2_alpha/test_password_policy.py
+++ b/tests/rest/client/test_password_policy.py
diff --git a/tests/rest/client/test_power_levels.py b/tests/rest/client/test_power_levels.py
index 91d0762cb0..c0de4c93a8 100644
--- a/tests/rest/client/test_power_levels.py
+++ b/tests/rest/client/test_power_levels.py
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from synapse.api.errors import Codes
+from synapse.events.utils import CANONICALJSON_MAX_INT, CANONICALJSON_MIN_INT
from synapse.rest import admin
from synapse.rest.client import login, room, sync
@@ -203,3 +205,79 @@ class PowerLevelsTestCase(HomeserverTestCase):
tok=self.admin_access_token,
expect_code=200, # expect success
)
+
+ def test_cannot_set_string_power_levels(self):
+ room_power_levels = self.helper.get_state(
+ self.room_id,
+ "m.room.power_levels",
+ tok=self.admin_access_token,
+ )
+
+ # Update existing power levels with user at PL "0"
+ room_power_levels["users"].update({self.user_user_id: "0"})
+
+ body = self.helper.send_state(
+ self.room_id,
+ "m.room.power_levels",
+ room_power_levels,
+ tok=self.admin_access_token,
+ expect_code=400, # expect failure
+ )
+
+ self.assertEqual(
+ body["errcode"],
+ Codes.BAD_JSON,
+ body,
+ )
+
+ def test_cannot_set_unsafe_large_power_levels(self):
+ room_power_levels = self.helper.get_state(
+ self.room_id,
+ "m.room.power_levels",
+ tok=self.admin_access_token,
+ )
+
+ # Update existing power levels with user at PL above the max safe integer
+ room_power_levels["users"].update(
+ {self.user_user_id: CANONICALJSON_MAX_INT + 1}
+ )
+
+ body = self.helper.send_state(
+ self.room_id,
+ "m.room.power_levels",
+ room_power_levels,
+ tok=self.admin_access_token,
+ expect_code=400, # expect failure
+ )
+
+ self.assertEqual(
+ body["errcode"],
+ Codes.BAD_JSON,
+ body,
+ )
+
+ def test_cannot_set_unsafe_small_power_levels(self):
+ room_power_levels = self.helper.get_state(
+ self.room_id,
+ "m.room.power_levels",
+ tok=self.admin_access_token,
+ )
+
+ # Update existing power levels with user at PL below the minimum safe integer
+ room_power_levels["users"].update(
+ {self.user_user_id: CANONICALJSON_MIN_INT - 1}
+ )
+
+ body = self.helper.send_state(
+ self.room_id,
+ "m.room.power_levels",
+ room_power_levels,
+ tok=self.admin_access_token,
+ expect_code=400, # expect failure
+ )
+
+ self.assertEqual(
+ body["errcode"],
+ Codes.BAD_JSON,
+ body,
+ )
diff --git a/tests/rest/client/v1/test_presence.py b/tests/rest/client/test_presence.py
index 1d152352d1..1d152352d1 100644
--- a/tests/rest/client/v1/test_presence.py
+++ b/tests/rest/client/test_presence.py
diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/test_profile.py
index 2860579c2e..2860579c2e 100644
--- a/tests/rest/client/v1/test_profile.py
+++ b/tests/rest/client/test_profile.py
diff --git a/tests/rest/client/v1/test_push_rule_attrs.py b/tests/rest/client/test_push_rule_attrs.py
index d0ce91ccd9..d0ce91ccd9 100644
--- a/tests/rest/client/v1/test_push_rule_attrs.py
+++ b/tests/rest/client/test_push_rule_attrs.py
diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/test_register.py
index fecda037a5..9f3ab2c985 100644
--- a/tests/rest/client/v2_alpha/test_register.py
+++ b/tests/rest/client/test_register.py
@@ -24,6 +24,7 @@ from synapse.api.constants import APP_SERVICE_REGISTRATION_TYPE, LoginType
from synapse.api.errors import Codes
from synapse.appservice import ApplicationService
from synapse.rest.client import account, account_validity, login, logout, register, sync
+from synapse.storage._base import db_to_json
from tests import unittest
from tests.unittest import override_config
@@ -204,6 +205,371 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.assertEquals(channel.result["code"], b"200", channel.result)
+ @override_config({"registration_requires_token": True})
+ def test_POST_registration_requires_token(self):
+ username = "kermit"
+ device_id = "frogfone"
+ token = "abcd"
+ store = self.hs.get_datastore()
+ self.get_success(
+ store.db_pool.simple_insert(
+ "registration_tokens",
+ {
+ "token": token,
+ "uses_allowed": None,
+ "pending": 0,
+ "completed": 0,
+ "expiry_time": None,
+ },
+ )
+ )
+ params = {
+ "username": username,
+ "password": "monkey",
+ "device_id": device_id,
+ }
+
+ # Request without auth to get flows and session
+ channel = self.make_request(b"POST", self.url, json.dumps(params))
+ self.assertEquals(channel.result["code"], b"401", channel.result)
+ flows = channel.json_body["flows"]
+ # Synapse adds a dummy stage to differentiate flows where otherwise one
+ # flow would be a subset of another flow.
+ self.assertCountEqual(
+ [[LoginType.REGISTRATION_TOKEN, LoginType.DUMMY]],
+ (f["stages"] for f in flows),
+ )
+ session = channel.json_body["session"]
+
+ # Do the registration token stage and check it has completed
+ params["auth"] = {
+ "type": LoginType.REGISTRATION_TOKEN,
+ "token": token,
+ "session": session,
+ }
+ request_data = json.dumps(params)
+ channel = self.make_request(b"POST", self.url, request_data)
+ self.assertEquals(channel.result["code"], b"401", channel.result)
+ completed = channel.json_body["completed"]
+ self.assertCountEqual([LoginType.REGISTRATION_TOKEN], completed)
+
+ # Do the m.login.dummy stage and check registration was successful
+ params["auth"] = {
+ "type": LoginType.DUMMY,
+ "session": session,
+ }
+ request_data = json.dumps(params)
+ channel = self.make_request(b"POST", self.url, request_data)
+ det_data = {
+ "user_id": f"@{username}:{self.hs.hostname}",
+ "home_server": self.hs.hostname,
+ "device_id": device_id,
+ }
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+ self.assertDictContainsSubset(det_data, channel.json_body)
+
+ # Check the `completed` counter has been incremented and pending is 0
+ res = self.get_success(
+ store.db_pool.simple_select_one(
+ "registration_tokens",
+ keyvalues={"token": token},
+ retcols=["pending", "completed"],
+ )
+ )
+ self.assertEquals(res["completed"], 1)
+ self.assertEquals(res["pending"], 0)
+
+ @override_config({"registration_requires_token": True})
+ def test_POST_registration_token_invalid(self):
+ params = {
+ "username": "kermit",
+ "password": "monkey",
+ }
+ # Request without auth to get session
+ channel = self.make_request(b"POST", self.url, json.dumps(params))
+ session = channel.json_body["session"]
+
+ # Test with token param missing (invalid)
+ params["auth"] = {
+ "type": LoginType.REGISTRATION_TOKEN,
+ "session": session,
+ }
+ channel = self.make_request(b"POST", self.url, json.dumps(params))
+ self.assertEquals(channel.result["code"], b"401", channel.result)
+ self.assertEquals(channel.json_body["errcode"], Codes.MISSING_PARAM)
+ self.assertEquals(channel.json_body["completed"], [])
+
+ # Test with non-string (invalid)
+ params["auth"]["token"] = 1234
+ channel = self.make_request(b"POST", self.url, json.dumps(params))
+ self.assertEquals(channel.result["code"], b"401", channel.result)
+ self.assertEquals(channel.json_body["errcode"], Codes.INVALID_PARAM)
+ self.assertEquals(channel.json_body["completed"], [])
+
+ # Test with unknown token (invalid)
+ params["auth"]["token"] = "1234"
+ channel = self.make_request(b"POST", self.url, json.dumps(params))
+ self.assertEquals(channel.result["code"], b"401", channel.result)
+ self.assertEquals(channel.json_body["errcode"], Codes.UNAUTHORIZED)
+ self.assertEquals(channel.json_body["completed"], [])
+
+ @override_config({"registration_requires_token": True})
+ def test_POST_registration_token_limit_uses(self):
+ token = "abcd"
+ store = self.hs.get_datastore()
+ # Create token that can be used once
+ self.get_success(
+ store.db_pool.simple_insert(
+ "registration_tokens",
+ {
+ "token": token,
+ "uses_allowed": 1,
+ "pending": 0,
+ "completed": 0,
+ "expiry_time": None,
+ },
+ )
+ )
+ params1 = {"username": "bert", "password": "monkey"}
+ params2 = {"username": "ernie", "password": "monkey"}
+ # Do 2 requests without auth to get two session IDs
+ channel1 = self.make_request(b"POST", self.url, json.dumps(params1))
+ session1 = channel1.json_body["session"]
+ channel2 = self.make_request(b"POST", self.url, json.dumps(params2))
+ session2 = channel2.json_body["session"]
+
+ # Use token with session1 and check `pending` is 1
+ params1["auth"] = {
+ "type": LoginType.REGISTRATION_TOKEN,
+ "token": token,
+ "session": session1,
+ }
+ self.make_request(b"POST", self.url, json.dumps(params1))
+ # Repeat request to make sure pending isn't increased again
+ self.make_request(b"POST", self.url, json.dumps(params1))
+ pending = self.get_success(
+ store.db_pool.simple_select_one_onecol(
+ "registration_tokens",
+ keyvalues={"token": token},
+ retcol="pending",
+ )
+ )
+ self.assertEquals(pending, 1)
+
+ # Check auth fails when using token with session2
+ params2["auth"] = {
+ "type": LoginType.REGISTRATION_TOKEN,
+ "token": token,
+ "session": session2,
+ }
+ channel = self.make_request(b"POST", self.url, json.dumps(params2))
+ self.assertEquals(channel.result["code"], b"401", channel.result)
+ self.assertEquals(channel.json_body["errcode"], Codes.UNAUTHORIZED)
+ self.assertEquals(channel.json_body["completed"], [])
+
+ # Complete registration with session1
+ params1["auth"]["type"] = LoginType.DUMMY
+ self.make_request(b"POST", self.url, json.dumps(params1))
+ # Check pending=0 and completed=1
+ res = self.get_success(
+ store.db_pool.simple_select_one(
+ "registration_tokens",
+ keyvalues={"token": token},
+ retcols=["pending", "completed"],
+ )
+ )
+ self.assertEquals(res["pending"], 0)
+ self.assertEquals(res["completed"], 1)
+
+ # Check auth still fails when using token with session2
+ channel = self.make_request(b"POST", self.url, json.dumps(params2))
+ self.assertEquals(channel.result["code"], b"401", channel.result)
+ self.assertEquals(channel.json_body["errcode"], Codes.UNAUTHORIZED)
+ self.assertEquals(channel.json_body["completed"], [])
+
+ @override_config({"registration_requires_token": True})
+ def test_POST_registration_token_expiry(self):
+ token = "abcd"
+ now = self.hs.get_clock().time_msec()
+ store = self.hs.get_datastore()
+ # Create token that expired yesterday
+ self.get_success(
+ store.db_pool.simple_insert(
+ "registration_tokens",
+ {
+ "token": token,
+ "uses_allowed": None,
+ "pending": 0,
+ "completed": 0,
+ "expiry_time": now - 24 * 60 * 60 * 1000,
+ },
+ )
+ )
+ params = {"username": "kermit", "password": "monkey"}
+ # Request without auth to get session
+ channel = self.make_request(b"POST", self.url, json.dumps(params))
+ session = channel.json_body["session"]
+
+ # Check authentication fails with expired token
+ params["auth"] = {
+ "type": LoginType.REGISTRATION_TOKEN,
+ "token": token,
+ "session": session,
+ }
+ channel = self.make_request(b"POST", self.url, json.dumps(params))
+ self.assertEquals(channel.result["code"], b"401", channel.result)
+ self.assertEquals(channel.json_body["errcode"], Codes.UNAUTHORIZED)
+ self.assertEquals(channel.json_body["completed"], [])
+
+ # Update token so it expires tomorrow
+ self.get_success(
+ store.db_pool.simple_update_one(
+ "registration_tokens",
+ keyvalues={"token": token},
+ updatevalues={"expiry_time": now + 24 * 60 * 60 * 1000},
+ )
+ )
+
+ # Check authentication succeeds
+ channel = self.make_request(b"POST", self.url, json.dumps(params))
+ completed = channel.json_body["completed"]
+ self.assertCountEqual([LoginType.REGISTRATION_TOKEN], completed)
+
+ @override_config({"registration_requires_token": True})
+ def test_POST_registration_token_session_expiry(self):
+ """Test `pending` is decremented when an uncompleted session expires."""
+ token = "abcd"
+ store = self.hs.get_datastore()
+ self.get_success(
+ store.db_pool.simple_insert(
+ "registration_tokens",
+ {
+ "token": token,
+ "uses_allowed": None,
+ "pending": 0,
+ "completed": 0,
+ "expiry_time": None,
+ },
+ )
+ )
+
+ # Do 2 requests without auth to get two session IDs
+ params1 = {"username": "bert", "password": "monkey"}
+ params2 = {"username": "ernie", "password": "monkey"}
+ channel1 = self.make_request(b"POST", self.url, json.dumps(params1))
+ session1 = channel1.json_body["session"]
+ channel2 = self.make_request(b"POST", self.url, json.dumps(params2))
+ session2 = channel2.json_body["session"]
+
+ # Use token with both sessions
+ params1["auth"] = {
+ "type": LoginType.REGISTRATION_TOKEN,
+ "token": token,
+ "session": session1,
+ }
+ self.make_request(b"POST", self.url, json.dumps(params1))
+
+ params2["auth"] = {
+ "type": LoginType.REGISTRATION_TOKEN,
+ "token": token,
+ "session": session2,
+ }
+ self.make_request(b"POST", self.url, json.dumps(params2))
+
+ # Complete registration with session1
+ params1["auth"]["type"] = LoginType.DUMMY
+ self.make_request(b"POST", self.url, json.dumps(params1))
+
+ # Check `result` of registration token stage for session1 is `True`
+ result1 = self.get_success(
+ store.db_pool.simple_select_one_onecol(
+ "ui_auth_sessions_credentials",
+ keyvalues={
+ "session_id": session1,
+ "stage_type": LoginType.REGISTRATION_TOKEN,
+ },
+ retcol="result",
+ )
+ )
+ self.assertTrue(db_to_json(result1))
+
+ # Check `result` for session2 is the token used
+ result2 = self.get_success(
+ store.db_pool.simple_select_one_onecol(
+ "ui_auth_sessions_credentials",
+ keyvalues={
+ "session_id": session2,
+ "stage_type": LoginType.REGISTRATION_TOKEN,
+ },
+ retcol="result",
+ )
+ )
+ self.assertEquals(db_to_json(result2), token)
+
+ # Delete both sessions (mimics expiry)
+ self.get_success(
+ store.delete_old_ui_auth_sessions(self.hs.get_clock().time_msec())
+ )
+
+ # Check pending is now 0
+ pending = self.get_success(
+ store.db_pool.simple_select_one_onecol(
+ "registration_tokens",
+ keyvalues={"token": token},
+ retcol="pending",
+ )
+ )
+ self.assertEquals(pending, 0)
+
+ @override_config({"registration_requires_token": True})
+ def test_POST_registration_token_session_expiry_deleted_token(self):
+ """Test session expiry doesn't break when the token is deleted.
+
+ 1. Start but don't complete UIA with a registration token
+ 2. Delete the token from the database
+ 3. Expire the session
+ """
+ token = "abcd"
+ store = self.hs.get_datastore()
+ self.get_success(
+ store.db_pool.simple_insert(
+ "registration_tokens",
+ {
+ "token": token,
+ "uses_allowed": None,
+ "pending": 0,
+ "completed": 0,
+ "expiry_time": None,
+ },
+ )
+ )
+
+ # Do request without auth to get a session ID
+ params = {"username": "kermit", "password": "monkey"}
+ channel = self.make_request(b"POST", self.url, json.dumps(params))
+ session = channel.json_body["session"]
+
+ # Use token
+ params["auth"] = {
+ "type": LoginType.REGISTRATION_TOKEN,
+ "token": token,
+ "session": session,
+ }
+ self.make_request(b"POST", self.url, json.dumps(params))
+
+ # Delete token
+ self.get_success(
+ store.db_pool.simple_delete_one(
+ "registration_tokens",
+ keyvalues={"token": token},
+ )
+ )
+
+ # Delete session (mimics expiry)
+ self.get_success(
+ store.delete_old_ui_auth_sessions(self.hs.get_clock().time_msec())
+ )
+
def test_advertised_flows(self):
channel = self.make_request(b"POST", self.url, b"{}")
self.assertEquals(channel.result["code"], b"401", channel.result)
@@ -744,3 +1110,71 @@ class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase):
self.assertGreaterEqual(res, now_ms + self.validity_period - self.max_delta)
self.assertLessEqual(res, now_ms + self.validity_period)
+
+
+class RegistrationTokenValidityRestServletTestCase(unittest.HomeserverTestCase):
+ servlets = [register.register_servlets]
+ url = "/_matrix/client/unstable/org.matrix.msc3231/register/org.matrix.msc3231.login.registration_token/validity"
+
+ def default_config(self):
+ config = super().default_config()
+ config["registration_requires_token"] = True
+ return config
+
+ def test_GET_token_valid(self):
+ token = "abcd"
+ store = self.hs.get_datastore()
+ self.get_success(
+ store.db_pool.simple_insert(
+ "registration_tokens",
+ {
+ "token": token,
+ "uses_allowed": None,
+ "pending": 0,
+ "completed": 0,
+ "expiry_time": None,
+ },
+ )
+ )
+
+ channel = self.make_request(
+ b"GET",
+ f"{self.url}?token={token}",
+ )
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+ self.assertEquals(channel.json_body["valid"], True)
+
+ def test_GET_token_invalid(self):
+ token = "1234"
+ channel = self.make_request(
+ b"GET",
+ f"{self.url}?token={token}",
+ )
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+ self.assertEquals(channel.json_body["valid"], False)
+
+ @override_config(
+ {"rc_registration_token_validity": {"per_second": 0.1, "burst_count": 5}}
+ )
+ def test_GET_ratelimiting(self):
+ token = "1234"
+
+ for i in range(0, 6):
+ channel = self.make_request(
+ b"GET",
+ f"{self.url}?token={token}",
+ )
+
+ if i == 5:
+ self.assertEquals(channel.result["code"], b"429", channel.result)
+ retry_after_ms = int(channel.json_body["retry_after_ms"])
+ else:
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+
+ self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
+
+ channel = self.make_request(
+ b"GET",
+ f"{self.url}?token={token}",
+ )
+ self.assertEquals(channel.result["code"], b"200", channel.result)
diff --git a/tests/rest/client/v2_alpha/test_relations.py b/tests/rest/client/test_relations.py
index 02b5e9a8d0..02b5e9a8d0 100644
--- a/tests/rest/client/v2_alpha/test_relations.py
+++ b/tests/rest/client/test_relations.py
diff --git a/tests/rest/client/v2_alpha/test_report_event.py b/tests/rest/client/test_report_event.py
index ee6b0b9ebf..ee6b0b9ebf 100644
--- a/tests/rest/client/v2_alpha/test_report_event.py
+++ b/tests/rest/client/test_report_event.py
diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/test_rooms.py
index 50100a5ae4..50100a5ae4 100644
--- a/tests/rest/client/v1/test_rooms.py
+++ b/tests/rest/client/test_rooms.py
diff --git a/tests/rest/client/v2_alpha/test_sendtodevice.py b/tests/rest/client/test_sendtodevice.py
index 6db7062a8e..6db7062a8e 100644
--- a/tests/rest/client/v2_alpha/test_sendtodevice.py
+++ b/tests/rest/client/test_sendtodevice.py
diff --git a/tests/rest/client/v2_alpha/test_shared_rooms.py b/tests/rest/client/test_shared_rooms.py
index 283eccd53f..283eccd53f 100644
--- a/tests/rest/client/v2_alpha/test_shared_rooms.py
+++ b/tests/rest/client/test_shared_rooms.py
diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/test_sync.py
index 95be369d4b..95be369d4b 100644
--- a/tests/rest/client/v2_alpha/test_sync.py
+++ b/tests/rest/client/test_sync.py
diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/test_typing.py
index b54b004733..b54b004733 100644
--- a/tests/rest/client/v1/test_typing.py
+++ b/tests/rest/client/test_typing.py
diff --git a/tests/rest/client/v2_alpha/test_upgrade_room.py b/tests/rest/client/test_upgrade_room.py
index 72f976d8e2..72f976d8e2 100644
--- a/tests/rest/client/v2_alpha/test_upgrade_room.py
+++ b/tests/rest/client/test_upgrade_room.py
diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/utils.py
index 954ad1a1fd..954ad1a1fd 100644
--- a/tests/rest/client/v1/utils.py
+++ b/tests/rest/client/utils.py
diff --git a/tests/rest/client/v1/__init__.py b/tests/rest/client/v1/__init__.py
deleted file mode 100644
index 5e83dba2ed..0000000000
--- a/tests/rest/client/v1/__init__.py
+++ /dev/null
@@ -1,13 +0,0 @@
-# Copyright 2014-2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
diff --git a/tests/rest/client/v2_alpha/__init__.py b/tests/rest/client/v2_alpha/__init__.py
deleted file mode 100644
index e69de29bb2..0000000000
--- a/tests/rest/client/v2_alpha/__init__.py
+++ /dev/null
diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py
index 6085444b9d..2f7eebfe69 100644
--- a/tests/rest/media/v1/test_media_storage.py
+++ b/tests/rest/media/v1/test_media_storage.py
@@ -21,7 +21,7 @@ from unittest.mock import Mock
from urllib import parse
import attr
-from parameterized import parameterized_class
+from parameterized import parameterized, parameterized_class
from PIL import Image as Image
from twisted.internet import defer
@@ -473,6 +473,43 @@ class MediaRepoTests(unittest.HomeserverTestCase):
},
)
+ @parameterized.expand([("crop", 16), ("crop", 64), ("scale", 16), ("scale", 64)])
+ def test_same_quality(self, method, desired_size):
+ """Test that choosing between thumbnails with the same quality rating succeeds.
+
+ We are not particular about which thumbnail is chosen."""
+ self.assertIsNotNone(
+ self.thumbnail_resource._select_thumbnail(
+ desired_width=desired_size,
+ desired_height=desired_size,
+ desired_method=method,
+ desired_type=self.test_image.content_type,
+ # Provide two identical thumbnails which are guaranteed to have the same
+ # quality rating.
+ thumbnail_infos=[
+ {
+ "thumbnail_width": 32,
+ "thumbnail_height": 32,
+ "thumbnail_method": method,
+ "thumbnail_type": self.test_image.content_type,
+ "thumbnail_length": 256,
+ "filesystem_id": f"thumbnail1{self.test_image.extension}",
+ },
+ {
+ "thumbnail_width": 32,
+ "thumbnail_height": 32,
+ "thumbnail_method": method,
+ "thumbnail_type": self.test_image.content_type,
+ "thumbnail_length": 256,
+ "filesystem_id": f"thumbnail2{self.test_image.extension}",
+ },
+ ],
+ file_id=f"image{self.test_image.extension}",
+ url_cache=None,
+ server_name=None,
+ )
+ )
+
def test_x_robots_tag_header(self):
"""
Tests that the `X-Robots-Tag` header is present, which informs web crawlers
|