diff --git a/tests/rest/admin/test_registration_tokens.py b/tests/rest/admin/test_registration_tokens.py
new file mode 100644
index 0000000000..4927321e5a
--- /dev/null
+++ b/tests/rest/admin/test_registration_tokens.py
@@ -0,0 +1,710 @@
+# Copyright 2021 Callum Brown
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import random
+import string
+
+import synapse.rest.admin
+from synapse.api.errors import Codes
+from synapse.rest.client import login
+
+from tests import unittest
+
+
+class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
+ self.admin_user = self.register_user("admin", "pass", admin=True)
+ self.admin_user_tok = self.login("admin", "pass")
+
+ self.other_user = self.register_user("user", "pass")
+ self.other_user_tok = self.login("user", "pass")
+
+ self.url = "/_synapse/admin/v1/registration_tokens"
+
+ def _new_token(self, **kwargs):
+ """Helper function to create a token."""
+ token = kwargs.get(
+ "token",
+ "".join(random.choices(string.ascii_letters, k=8)),
+ )
+ self.get_success(
+ self.store.db_pool.simple_insert(
+ "registration_tokens",
+ {
+ "token": token,
+ "uses_allowed": kwargs.get("uses_allowed", None),
+ "pending": kwargs.get("pending", 0),
+ "completed": kwargs.get("completed", 0),
+ "expiry_time": kwargs.get("expiry_time", None),
+ },
+ )
+ )
+ return token
+
+ # CREATION
+
+ def test_create_no_auth(self):
+ """Try to create a token without authentication."""
+ channel = self.make_request("POST", self.url + "/new", {})
+ self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+ def test_create_requester_not_admin(self):
+ """Try to create a token while not an admin."""
+ channel = self.make_request(
+ "POST",
+ self.url + "/new",
+ {},
+ access_token=self.other_user_tok,
+ )
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ def test_create_using_defaults(self):
+ """Create a token using all the defaults."""
+ channel = self.make_request(
+ "POST",
+ self.url + "/new",
+ {},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(len(channel.json_body["token"]), 16)
+ self.assertIsNone(channel.json_body["uses_allowed"])
+ self.assertIsNone(channel.json_body["expiry_time"])
+ self.assertEqual(channel.json_body["pending"], 0)
+ self.assertEqual(channel.json_body["completed"], 0)
+
+ def test_create_specifying_fields(self):
+ """Create a token specifying the value of all fields."""
+ data = {
+ "token": "abcd",
+ "uses_allowed": 1,
+ "expiry_time": self.clock.time_msec() + 1000000,
+ }
+
+ channel = self.make_request(
+ "POST",
+ self.url + "/new",
+ data,
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["token"], "abcd")
+ self.assertEqual(channel.json_body["uses_allowed"], 1)
+ self.assertEqual(channel.json_body["expiry_time"], data["expiry_time"])
+ self.assertEqual(channel.json_body["pending"], 0)
+ self.assertEqual(channel.json_body["completed"], 0)
+
+ def test_create_with_null_value(self):
+ """Create a token specifying unlimited uses and no expiry."""
+ data = {
+ "uses_allowed": None,
+ "expiry_time": None,
+ }
+
+ channel = self.make_request(
+ "POST",
+ self.url + "/new",
+ data,
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(len(channel.json_body["token"]), 16)
+ self.assertIsNone(channel.json_body["uses_allowed"])
+ self.assertIsNone(channel.json_body["expiry_time"])
+ self.assertEqual(channel.json_body["pending"], 0)
+ self.assertEqual(channel.json_body["completed"], 0)
+
+ def test_create_token_too_long(self):
+ """Check token longer than 64 chars is invalid."""
+ data = {"token": "a" * 65}
+
+ channel = self.make_request(
+ "POST",
+ self.url + "/new",
+ data,
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+ def test_create_token_invalid_chars(self):
+ """Check you can't create token with invalid characters."""
+ data = {
+ "token": "abc/def",
+ }
+
+ channel = self.make_request(
+ "POST",
+ self.url + "/new",
+ data,
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+ def test_create_token_already_exists(self):
+ """Check you can't create token that already exists."""
+ data = {
+ "token": "abcd",
+ }
+
+ channel1 = self.make_request(
+ "POST",
+ self.url + "/new",
+ data,
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, int(channel1.result["code"]), msg=channel1.result["body"])
+
+ channel2 = self.make_request(
+ "POST",
+ self.url + "/new",
+ data,
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(400, int(channel2.result["code"]), msg=channel2.result["body"])
+ self.assertEqual(channel2.json_body["errcode"], Codes.INVALID_PARAM)
+
+ def test_create_unable_to_generate_token(self):
+ """Check right error is raised when server can't generate unique token."""
+ # Create all possible single character tokens
+ tokens = []
+ for c in string.ascii_letters + string.digits + "-_":
+ tokens.append(
+ {
+ "token": c,
+ "uses_allowed": None,
+ "pending": 0,
+ "completed": 0,
+ "expiry_time": None,
+ }
+ )
+ self.get_success(
+ self.store.db_pool.simple_insert_many(
+ "registration_tokens",
+ tokens,
+ "create_all_registration_tokens",
+ )
+ )
+
+ # Check creating a single character token fails with a 500 status code
+ channel = self.make_request(
+ "POST",
+ self.url + "/new",
+ {"length": 1},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(500, int(channel.result["code"]), msg=channel.result["body"])
+
+ def test_create_uses_allowed(self):
+ """Check you can only create a token with good values for uses_allowed."""
+ # Should work with 0 (token is invalid from the start)
+ channel = self.make_request(
+ "POST",
+ self.url + "/new",
+ {"uses_allowed": 0},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["uses_allowed"], 0)
+
+ # Should fail with negative integer
+ channel = self.make_request(
+ "POST",
+ self.url + "/new",
+ {"uses_allowed": -5},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+ # Should fail with float
+ channel = self.make_request(
+ "POST",
+ self.url + "/new",
+ {"uses_allowed": 1.5},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+ def test_create_expiry_time(self):
+ """Check you can't create a token with an invalid expiry_time."""
+ # Should fail with a time in the past
+ channel = self.make_request(
+ "POST",
+ self.url + "/new",
+ {"expiry_time": self.clock.time_msec() - 10000},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+ # Should fail with float
+ channel = self.make_request(
+ "POST",
+ self.url + "/new",
+ {"expiry_time": self.clock.time_msec() + 1000000.5},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+ def test_create_length(self):
+ """Check you can only generate a token with a valid length."""
+ # Should work with 64
+ channel = self.make_request(
+ "POST",
+ self.url + "/new",
+ {"length": 64},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(len(channel.json_body["token"]), 64)
+
+ # Should fail with 0
+ channel = self.make_request(
+ "POST",
+ self.url + "/new",
+ {"length": 0},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+ # Should fail with a negative integer
+ channel = self.make_request(
+ "POST",
+ self.url + "/new",
+ {"length": -5},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+ # Should fail with a float
+ channel = self.make_request(
+ "POST",
+ self.url + "/new",
+ {"length": 8.5},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+ # Should fail with 65
+ channel = self.make_request(
+ "POST",
+ self.url + "/new",
+ {"length": 65},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+ # UPDATING
+
+ def test_update_no_auth(self):
+ """Try to update a token without authentication."""
+ channel = self.make_request(
+ "PUT",
+ self.url + "/1234", # Token doesn't exist but that doesn't matter
+ {},
+ )
+ self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+ def test_update_requester_not_admin(self):
+ """Try to update a token while not an admin."""
+ channel = self.make_request(
+ "PUT",
+ self.url + "/1234", # Token doesn't exist but that doesn't matter
+ {},
+ access_token=self.other_user_tok,
+ )
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ def test_update_non_existent(self):
+ """Try to update a token that doesn't exist."""
+ channel = self.make_request(
+ "PUT",
+ self.url + "/1234",
+ {"uses_allowed": 1},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
+
+ def test_update_uses_allowed(self):
+ """Test updating just uses_allowed."""
+ # Create new token using default values
+ token = self._new_token()
+
+ # Should succeed with 1
+ channel = self.make_request(
+ "PUT",
+ self.url + "/" + token,
+ {"uses_allowed": 1},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["uses_allowed"], 1)
+ self.assertIsNone(channel.json_body["expiry_time"])
+
+ # Should succeed with 0 (makes token invalid)
+ channel = self.make_request(
+ "PUT",
+ self.url + "/" + token,
+ {"uses_allowed": 0},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["uses_allowed"], 0)
+ self.assertIsNone(channel.json_body["expiry_time"])
+
+ # Should succeed with null
+ channel = self.make_request(
+ "PUT",
+ self.url + "/" + token,
+ {"uses_allowed": None},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertIsNone(channel.json_body["uses_allowed"])
+ self.assertIsNone(channel.json_body["expiry_time"])
+
+ # Should fail with a float
+ channel = self.make_request(
+ "PUT",
+ self.url + "/" + token,
+ {"uses_allowed": 1.5},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+ # Should fail with a negative integer
+ channel = self.make_request(
+ "PUT",
+ self.url + "/" + token,
+ {"uses_allowed": -5},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+ def test_update_expiry_time(self):
+ """Test updating just expiry_time."""
+ # Create new token using default values
+ token = self._new_token()
+ new_expiry_time = self.clock.time_msec() + 1000000
+
+ # Should succeed with a time in the future
+ channel = self.make_request(
+ "PUT",
+ self.url + "/" + token,
+ {"expiry_time": new_expiry_time},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["expiry_time"], new_expiry_time)
+ self.assertIsNone(channel.json_body["uses_allowed"])
+
+ # Should succeed with null
+ channel = self.make_request(
+ "PUT",
+ self.url + "/" + token,
+ {"expiry_time": None},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertIsNone(channel.json_body["expiry_time"])
+ self.assertIsNone(channel.json_body["uses_allowed"])
+
+ # Should fail with a time in the past
+ past_time = self.clock.time_msec() - 10000
+ channel = self.make_request(
+ "PUT",
+ self.url + "/" + token,
+ {"expiry_time": past_time},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+ # Should fail a float
+ channel = self.make_request(
+ "PUT",
+ self.url + "/" + token,
+ {"expiry_time": new_expiry_time + 0.5},
+ access_token=self.admin_user_tok,
+ )
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+ def test_update_both(self):
+ """Test updating both uses_allowed and expiry_time."""
+ # Create new token using default values
+ token = self._new_token()
+ new_expiry_time = self.clock.time_msec() + 1000000
+
+ data = {
+ "uses_allowed": 1,
+ "expiry_time": new_expiry_time,
+ }
+
+ channel = self.make_request(
+ "PUT",
+ self.url + "/" + token,
+ data,
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["uses_allowed"], 1)
+ self.assertEqual(channel.json_body["expiry_time"], new_expiry_time)
+
+ def test_update_invalid_type(self):
+ """Test using invalid types doesn't work."""
+ # Create new token using default values
+ token = self._new_token()
+
+ data = {
+ "uses_allowed": False,
+ "expiry_time": "1626430124000",
+ }
+
+ channel = self.make_request(
+ "PUT",
+ self.url + "/" + token,
+ data,
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+ # DELETING
+
+ def test_delete_no_auth(self):
+ """Try to delete a token without authentication."""
+ channel = self.make_request(
+ "DELETE",
+ self.url + "/1234", # Token doesn't exist but that doesn't matter
+ {},
+ )
+ self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+ def test_delete_requester_not_admin(self):
+ """Try to delete a token while not an admin."""
+ channel = self.make_request(
+ "DELETE",
+ self.url + "/1234", # Token doesn't exist but that doesn't matter
+ {},
+ access_token=self.other_user_tok,
+ )
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ def test_delete_non_existent(self):
+ """Try to delete a token that doesn't exist."""
+ channel = self.make_request(
+ "DELETE",
+ self.url + "/1234",
+ {},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
+
+ def test_delete(self):
+ """Test deleting a token."""
+ # Create new token using default values
+ token = self._new_token()
+
+ channel = self.make_request(
+ "DELETE",
+ self.url + "/" + token,
+ {},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+ # GETTING ONE
+
+ def test_get_no_auth(self):
+ """Try to get a token without authentication."""
+ channel = self.make_request(
+ "GET",
+ self.url + "/1234", # Token doesn't exist but that doesn't matter
+ {},
+ )
+ self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+ def test_get_requester_not_admin(self):
+ """Try to get a token while not an admin."""
+ channel = self.make_request(
+ "GET",
+ self.url + "/1234", # Token doesn't exist but that doesn't matter
+ {},
+ access_token=self.other_user_tok,
+ )
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ def test_get_non_existent(self):
+ """Try to get a token that doesn't exist."""
+ channel = self.make_request(
+ "GET",
+ self.url + "/1234",
+ {},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
+
+ def test_get(self):
+ """Test getting a token."""
+ # Create new token using default values
+ token = self._new_token()
+
+ channel = self.make_request(
+ "GET",
+ self.url + "/" + token,
+ {},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(channel.json_body["token"], token)
+ self.assertIsNone(channel.json_body["uses_allowed"])
+ self.assertIsNone(channel.json_body["expiry_time"])
+ self.assertEqual(channel.json_body["pending"], 0)
+ self.assertEqual(channel.json_body["completed"], 0)
+
+ # LISTING
+
+ def test_list_no_auth(self):
+ """Try to list tokens without authentication."""
+ channel = self.make_request("GET", self.url, {})
+ self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+ def test_list_requester_not_admin(self):
+ """Try to list tokens while not an admin."""
+ channel = self.make_request(
+ "GET",
+ self.url,
+ {},
+ access_token=self.other_user_tok,
+ )
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+ def test_list_all(self):
+ """Test listing all tokens."""
+ # Create new token using default values
+ token = self._new_token()
+
+ channel = self.make_request(
+ "GET",
+ self.url,
+ {},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(len(channel.json_body["registration_tokens"]), 1)
+ token_info = channel.json_body["registration_tokens"][0]
+ self.assertEqual(token_info["token"], token)
+ self.assertIsNone(token_info["uses_allowed"])
+ self.assertIsNone(token_info["expiry_time"])
+ self.assertEqual(token_info["pending"], 0)
+ self.assertEqual(token_info["completed"], 0)
+
+ def test_list_invalid_query_parameter(self):
+ """Test with `valid` query parameter not `true` or `false`."""
+ channel = self.make_request(
+ "GET",
+ self.url + "?valid=x",
+ {},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+
+ def _test_list_query_parameter(self, valid: str):
+ """Helper used to test both valid=true and valid=false."""
+ # Create 2 valid and 2 invalid tokens.
+ now = self.hs.get_clock().time_msec()
+ # Create always valid token
+ valid1 = self._new_token()
+ # Create token that hasn't been used up
+ valid2 = self._new_token(uses_allowed=1)
+ # Create token that has expired
+ invalid1 = self._new_token(expiry_time=now - 10000)
+ # Create token that has been used up but hasn't expired
+ invalid2 = self._new_token(
+ uses_allowed=2,
+ pending=1,
+ completed=1,
+ expiry_time=now + 1000000,
+ )
+
+ if valid == "true":
+ tokens = [valid1, valid2]
+ else:
+ tokens = [invalid1, invalid2]
+
+ channel = self.make_request(
+ "GET",
+ self.url + "?valid=" + valid,
+ {},
+ access_token=self.admin_user_tok,
+ )
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(len(channel.json_body["registration_tokens"]), 2)
+ token_info_1 = channel.json_body["registration_tokens"][0]
+ token_info_2 = channel.json_body["registration_tokens"][1]
+ self.assertIn(token_info_1["token"], tokens)
+ self.assertIn(token_info_2["token"], tokens)
+
+ def test_list_valid(self):
+ """Test listing just valid tokens."""
+ self._test_list_query_parameter(valid="true")
+
+ def test_list_invalid(self):
+ """Test listing just invalid tokens."""
+ self._test_list_query_parameter(valid="false")
diff --git a/tests/rest/client/test_register.py b/tests/rest/client/test_register.py
index fecda037a5..9f3ab2c985 100644
--- a/tests/rest/client/test_register.py
+++ b/tests/rest/client/test_register.py
@@ -24,6 +24,7 @@ from synapse.api.constants import APP_SERVICE_REGISTRATION_TYPE, LoginType
from synapse.api.errors import Codes
from synapse.appservice import ApplicationService
from synapse.rest.client import account, account_validity, login, logout, register, sync
+from synapse.storage._base import db_to_json
from tests import unittest
from tests.unittest import override_config
@@ -204,6 +205,371 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.assertEquals(channel.result["code"], b"200", channel.result)
+ @override_config({"registration_requires_token": True})
+ def test_POST_registration_requires_token(self):
+ username = "kermit"
+ device_id = "frogfone"
+ token = "abcd"
+ store = self.hs.get_datastore()
+ self.get_success(
+ store.db_pool.simple_insert(
+ "registration_tokens",
+ {
+ "token": token,
+ "uses_allowed": None,
+ "pending": 0,
+ "completed": 0,
+ "expiry_time": None,
+ },
+ )
+ )
+ params = {
+ "username": username,
+ "password": "monkey",
+ "device_id": device_id,
+ }
+
+ # Request without auth to get flows and session
+ channel = self.make_request(b"POST", self.url, json.dumps(params))
+ self.assertEquals(channel.result["code"], b"401", channel.result)
+ flows = channel.json_body["flows"]
+ # Synapse adds a dummy stage to differentiate flows where otherwise one
+ # flow would be a subset of another flow.
+ self.assertCountEqual(
+ [[LoginType.REGISTRATION_TOKEN, LoginType.DUMMY]],
+ (f["stages"] for f in flows),
+ )
+ session = channel.json_body["session"]
+
+ # Do the registration token stage and check it has completed
+ params["auth"] = {
+ "type": LoginType.REGISTRATION_TOKEN,
+ "token": token,
+ "session": session,
+ }
+ request_data = json.dumps(params)
+ channel = self.make_request(b"POST", self.url, request_data)
+ self.assertEquals(channel.result["code"], b"401", channel.result)
+ completed = channel.json_body["completed"]
+ self.assertCountEqual([LoginType.REGISTRATION_TOKEN], completed)
+
+ # Do the m.login.dummy stage and check registration was successful
+ params["auth"] = {
+ "type": LoginType.DUMMY,
+ "session": session,
+ }
+ request_data = json.dumps(params)
+ channel = self.make_request(b"POST", self.url, request_data)
+ det_data = {
+ "user_id": f"@{username}:{self.hs.hostname}",
+ "home_server": self.hs.hostname,
+ "device_id": device_id,
+ }
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+ self.assertDictContainsSubset(det_data, channel.json_body)
+
+ # Check the `completed` counter has been incremented and pending is 0
+ res = self.get_success(
+ store.db_pool.simple_select_one(
+ "registration_tokens",
+ keyvalues={"token": token},
+ retcols=["pending", "completed"],
+ )
+ )
+ self.assertEquals(res["completed"], 1)
+ self.assertEquals(res["pending"], 0)
+
+ @override_config({"registration_requires_token": True})
+ def test_POST_registration_token_invalid(self):
+ params = {
+ "username": "kermit",
+ "password": "monkey",
+ }
+ # Request without auth to get session
+ channel = self.make_request(b"POST", self.url, json.dumps(params))
+ session = channel.json_body["session"]
+
+ # Test with token param missing (invalid)
+ params["auth"] = {
+ "type": LoginType.REGISTRATION_TOKEN,
+ "session": session,
+ }
+ channel = self.make_request(b"POST", self.url, json.dumps(params))
+ self.assertEquals(channel.result["code"], b"401", channel.result)
+ self.assertEquals(channel.json_body["errcode"], Codes.MISSING_PARAM)
+ self.assertEquals(channel.json_body["completed"], [])
+
+ # Test with non-string (invalid)
+ params["auth"]["token"] = 1234
+ channel = self.make_request(b"POST", self.url, json.dumps(params))
+ self.assertEquals(channel.result["code"], b"401", channel.result)
+ self.assertEquals(channel.json_body["errcode"], Codes.INVALID_PARAM)
+ self.assertEquals(channel.json_body["completed"], [])
+
+ # Test with unknown token (invalid)
+ params["auth"]["token"] = "1234"
+ channel = self.make_request(b"POST", self.url, json.dumps(params))
+ self.assertEquals(channel.result["code"], b"401", channel.result)
+ self.assertEquals(channel.json_body["errcode"], Codes.UNAUTHORIZED)
+ self.assertEquals(channel.json_body["completed"], [])
+
+ @override_config({"registration_requires_token": True})
+ def test_POST_registration_token_limit_uses(self):
+ token = "abcd"
+ store = self.hs.get_datastore()
+ # Create token that can be used once
+ self.get_success(
+ store.db_pool.simple_insert(
+ "registration_tokens",
+ {
+ "token": token,
+ "uses_allowed": 1,
+ "pending": 0,
+ "completed": 0,
+ "expiry_time": None,
+ },
+ )
+ )
+ params1 = {"username": "bert", "password": "monkey"}
+ params2 = {"username": "ernie", "password": "monkey"}
+ # Do 2 requests without auth to get two session IDs
+ channel1 = self.make_request(b"POST", self.url, json.dumps(params1))
+ session1 = channel1.json_body["session"]
+ channel2 = self.make_request(b"POST", self.url, json.dumps(params2))
+ session2 = channel2.json_body["session"]
+
+ # Use token with session1 and check `pending` is 1
+ params1["auth"] = {
+ "type": LoginType.REGISTRATION_TOKEN,
+ "token": token,
+ "session": session1,
+ }
+ self.make_request(b"POST", self.url, json.dumps(params1))
+ # Repeat request to make sure pending isn't increased again
+ self.make_request(b"POST", self.url, json.dumps(params1))
+ pending = self.get_success(
+ store.db_pool.simple_select_one_onecol(
+ "registration_tokens",
+ keyvalues={"token": token},
+ retcol="pending",
+ )
+ )
+ self.assertEquals(pending, 1)
+
+ # Check auth fails when using token with session2
+ params2["auth"] = {
+ "type": LoginType.REGISTRATION_TOKEN,
+ "token": token,
+ "session": session2,
+ }
+ channel = self.make_request(b"POST", self.url, json.dumps(params2))
+ self.assertEquals(channel.result["code"], b"401", channel.result)
+ self.assertEquals(channel.json_body["errcode"], Codes.UNAUTHORIZED)
+ self.assertEquals(channel.json_body["completed"], [])
+
+ # Complete registration with session1
+ params1["auth"]["type"] = LoginType.DUMMY
+ self.make_request(b"POST", self.url, json.dumps(params1))
+ # Check pending=0 and completed=1
+ res = self.get_success(
+ store.db_pool.simple_select_one(
+ "registration_tokens",
+ keyvalues={"token": token},
+ retcols=["pending", "completed"],
+ )
+ )
+ self.assertEquals(res["pending"], 0)
+ self.assertEquals(res["completed"], 1)
+
+ # Check auth still fails when using token with session2
+ channel = self.make_request(b"POST", self.url, json.dumps(params2))
+ self.assertEquals(channel.result["code"], b"401", channel.result)
+ self.assertEquals(channel.json_body["errcode"], Codes.UNAUTHORIZED)
+ self.assertEquals(channel.json_body["completed"], [])
+
+ @override_config({"registration_requires_token": True})
+ def test_POST_registration_token_expiry(self):
+ token = "abcd"
+ now = self.hs.get_clock().time_msec()
+ store = self.hs.get_datastore()
+ # Create token that expired yesterday
+ self.get_success(
+ store.db_pool.simple_insert(
+ "registration_tokens",
+ {
+ "token": token,
+ "uses_allowed": None,
+ "pending": 0,
+ "completed": 0,
+ "expiry_time": now - 24 * 60 * 60 * 1000,
+ },
+ )
+ )
+ params = {"username": "kermit", "password": "monkey"}
+ # Request without auth to get session
+ channel = self.make_request(b"POST", self.url, json.dumps(params))
+ session = channel.json_body["session"]
+
+ # Check authentication fails with expired token
+ params["auth"] = {
+ "type": LoginType.REGISTRATION_TOKEN,
+ "token": token,
+ "session": session,
+ }
+ channel = self.make_request(b"POST", self.url, json.dumps(params))
+ self.assertEquals(channel.result["code"], b"401", channel.result)
+ self.assertEquals(channel.json_body["errcode"], Codes.UNAUTHORIZED)
+ self.assertEquals(channel.json_body["completed"], [])
+
+ # Update token so it expires tomorrow
+ self.get_success(
+ store.db_pool.simple_update_one(
+ "registration_tokens",
+ keyvalues={"token": token},
+ updatevalues={"expiry_time": now + 24 * 60 * 60 * 1000},
+ )
+ )
+
+ # Check authentication succeeds
+ channel = self.make_request(b"POST", self.url, json.dumps(params))
+ completed = channel.json_body["completed"]
+ self.assertCountEqual([LoginType.REGISTRATION_TOKEN], completed)
+
+ @override_config({"registration_requires_token": True})
+ def test_POST_registration_token_session_expiry(self):
+ """Test `pending` is decremented when an uncompleted session expires."""
+ token = "abcd"
+ store = self.hs.get_datastore()
+ self.get_success(
+ store.db_pool.simple_insert(
+ "registration_tokens",
+ {
+ "token": token,
+ "uses_allowed": None,
+ "pending": 0,
+ "completed": 0,
+ "expiry_time": None,
+ },
+ )
+ )
+
+ # Do 2 requests without auth to get two session IDs
+ params1 = {"username": "bert", "password": "monkey"}
+ params2 = {"username": "ernie", "password": "monkey"}
+ channel1 = self.make_request(b"POST", self.url, json.dumps(params1))
+ session1 = channel1.json_body["session"]
+ channel2 = self.make_request(b"POST", self.url, json.dumps(params2))
+ session2 = channel2.json_body["session"]
+
+ # Use token with both sessions
+ params1["auth"] = {
+ "type": LoginType.REGISTRATION_TOKEN,
+ "token": token,
+ "session": session1,
+ }
+ self.make_request(b"POST", self.url, json.dumps(params1))
+
+ params2["auth"] = {
+ "type": LoginType.REGISTRATION_TOKEN,
+ "token": token,
+ "session": session2,
+ }
+ self.make_request(b"POST", self.url, json.dumps(params2))
+
+ # Complete registration with session1
+ params1["auth"]["type"] = LoginType.DUMMY
+ self.make_request(b"POST", self.url, json.dumps(params1))
+
+ # Check `result` of registration token stage for session1 is `True`
+ result1 = self.get_success(
+ store.db_pool.simple_select_one_onecol(
+ "ui_auth_sessions_credentials",
+ keyvalues={
+ "session_id": session1,
+ "stage_type": LoginType.REGISTRATION_TOKEN,
+ },
+ retcol="result",
+ )
+ )
+ self.assertTrue(db_to_json(result1))
+
+ # Check `result` for session2 is the token used
+ result2 = self.get_success(
+ store.db_pool.simple_select_one_onecol(
+ "ui_auth_sessions_credentials",
+ keyvalues={
+ "session_id": session2,
+ "stage_type": LoginType.REGISTRATION_TOKEN,
+ },
+ retcol="result",
+ )
+ )
+ self.assertEquals(db_to_json(result2), token)
+
+ # Delete both sessions (mimics expiry)
+ self.get_success(
+ store.delete_old_ui_auth_sessions(self.hs.get_clock().time_msec())
+ )
+
+ # Check pending is now 0
+ pending = self.get_success(
+ store.db_pool.simple_select_one_onecol(
+ "registration_tokens",
+ keyvalues={"token": token},
+ retcol="pending",
+ )
+ )
+ self.assertEquals(pending, 0)
+
+ @override_config({"registration_requires_token": True})
+ def test_POST_registration_token_session_expiry_deleted_token(self):
+ """Test session expiry doesn't break when the token is deleted.
+
+ 1. Start but don't complete UIA with a registration token
+ 2. Delete the token from the database
+ 3. Expire the session
+ """
+ token = "abcd"
+ store = self.hs.get_datastore()
+ self.get_success(
+ store.db_pool.simple_insert(
+ "registration_tokens",
+ {
+ "token": token,
+ "uses_allowed": None,
+ "pending": 0,
+ "completed": 0,
+ "expiry_time": None,
+ },
+ )
+ )
+
+ # Do request without auth to get a session ID
+ params = {"username": "kermit", "password": "monkey"}
+ channel = self.make_request(b"POST", self.url, json.dumps(params))
+ session = channel.json_body["session"]
+
+ # Use token
+ params["auth"] = {
+ "type": LoginType.REGISTRATION_TOKEN,
+ "token": token,
+ "session": session,
+ }
+ self.make_request(b"POST", self.url, json.dumps(params))
+
+ # Delete token
+ self.get_success(
+ store.db_pool.simple_delete_one(
+ "registration_tokens",
+ keyvalues={"token": token},
+ )
+ )
+
+ # Delete session (mimics expiry)
+ self.get_success(
+ store.delete_old_ui_auth_sessions(self.hs.get_clock().time_msec())
+ )
+
def test_advertised_flows(self):
channel = self.make_request(b"POST", self.url, b"{}")
self.assertEquals(channel.result["code"], b"401", channel.result)
@@ -744,3 +1110,71 @@ class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase):
self.assertGreaterEqual(res, now_ms + self.validity_period - self.max_delta)
self.assertLessEqual(res, now_ms + self.validity_period)
+
+
+class RegistrationTokenValidityRestServletTestCase(unittest.HomeserverTestCase):
+ servlets = [register.register_servlets]
+ url = "/_matrix/client/unstable/org.matrix.msc3231/register/org.matrix.msc3231.login.registration_token/validity"
+
+ def default_config(self):
+ config = super().default_config()
+ config["registration_requires_token"] = True
+ return config
+
+ def test_GET_token_valid(self):
+ token = "abcd"
+ store = self.hs.get_datastore()
+ self.get_success(
+ store.db_pool.simple_insert(
+ "registration_tokens",
+ {
+ "token": token,
+ "uses_allowed": None,
+ "pending": 0,
+ "completed": 0,
+ "expiry_time": None,
+ },
+ )
+ )
+
+ channel = self.make_request(
+ b"GET",
+ f"{self.url}?token={token}",
+ )
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+ self.assertEquals(channel.json_body["valid"], True)
+
+ def test_GET_token_invalid(self):
+ token = "1234"
+ channel = self.make_request(
+ b"GET",
+ f"{self.url}?token={token}",
+ )
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+ self.assertEquals(channel.json_body["valid"], False)
+
+ @override_config(
+ {"rc_registration_token_validity": {"per_second": 0.1, "burst_count": 5}}
+ )
+ def test_GET_ratelimiting(self):
+ token = "1234"
+
+ for i in range(0, 6):
+ channel = self.make_request(
+ b"GET",
+ f"{self.url}?token={token}",
+ )
+
+ if i == 5:
+ self.assertEquals(channel.result["code"], b"429", channel.result)
+ retry_after_ms = int(channel.json_body["retry_after_ms"])
+ else:
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+
+ self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
+
+ channel = self.make_request(
+ b"GET",
+ f"{self.url}?token={token}",
+ )
+ self.assertEquals(channel.result["code"], b"200", channel.result)
|