summary refs log tree commit diff
path: root/tests/rest
diff options
context:
space:
mode:
authorCallum Brown <callum@calcuode.com>2021-08-21 22:14:43 +0100
committerGitHub <noreply@github.com>2021-08-21 22:14:43 +0100
commit947dbbdfd1e0029da66f956d277b7c089928e1e7 (patch)
tree57cf53bcbb1f02e75e114cde5d0aa77662163038 /tests/rest
parentFlatten tests/rest/client/{v1,v2_alpha} too (#10667) (diff)
downloadsynapse-947dbbdfd1e0029da66f956d277b7c089928e1e7.tar.xz
Implement MSC3231: Token authenticated registration (#10142)
Signed-off-by: Callum Brown <callum@calcuode.com>

This is part of my GSoC project implementing [MSC3231](https://github.com/matrix-org/matrix-doc/pull/3231).
Diffstat (limited to 'tests/rest')
-rw-r--r--tests/rest/admin/test_registration_tokens.py710
-rw-r--r--tests/rest/client/test_register.py434
2 files changed, 1144 insertions, 0 deletions
diff --git a/tests/rest/admin/test_registration_tokens.py b/tests/rest/admin/test_registration_tokens.py
new file mode 100644
index 0000000000..4927321e5a
--- /dev/null
+++ b/tests/rest/admin/test_registration_tokens.py
@@ -0,0 +1,710 @@
+# Copyright 2021 Callum Brown
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import random
+import string
+
+import synapse.rest.admin
+from synapse.api.errors import Codes
+from synapse.rest.client import login
+
+from tests import unittest
+
+
+class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
+    servlets = [
+        synapse.rest.admin.register_servlets,
+        login.register_servlets,
+    ]
+
+    def prepare(self, reactor, clock, hs):
+        self.store = hs.get_datastore()
+        self.admin_user = self.register_user("admin", "pass", admin=True)
+        self.admin_user_tok = self.login("admin", "pass")
+
+        self.other_user = self.register_user("user", "pass")
+        self.other_user_tok = self.login("user", "pass")
+
+        self.url = "/_synapse/admin/v1/registration_tokens"
+
+    def _new_token(self, **kwargs):
+        """Helper function to create a token."""
+        token = kwargs.get(
+            "token",
+            "".join(random.choices(string.ascii_letters, k=8)),
+        )
+        self.get_success(
+            self.store.db_pool.simple_insert(
+                "registration_tokens",
+                {
+                    "token": token,
+                    "uses_allowed": kwargs.get("uses_allowed", None),
+                    "pending": kwargs.get("pending", 0),
+                    "completed": kwargs.get("completed", 0),
+                    "expiry_time": kwargs.get("expiry_time", None),
+                },
+            )
+        )
+        return token
+
+    # CREATION
+
+    def test_create_no_auth(self):
+        """Try to create a token without authentication."""
+        channel = self.make_request("POST", self.url + "/new", {})
+        self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+    def test_create_requester_not_admin(self):
+        """Try to create a token while not an admin."""
+        channel = self.make_request(
+            "POST",
+            self.url + "/new",
+            {},
+            access_token=self.other_user_tok,
+        )
+        self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+    def test_create_using_defaults(self):
+        """Create a token using all the defaults."""
+        channel = self.make_request(
+            "POST",
+            self.url + "/new",
+            {},
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(len(channel.json_body["token"]), 16)
+        self.assertIsNone(channel.json_body["uses_allowed"])
+        self.assertIsNone(channel.json_body["expiry_time"])
+        self.assertEqual(channel.json_body["pending"], 0)
+        self.assertEqual(channel.json_body["completed"], 0)
+
+    def test_create_specifying_fields(self):
+        """Create a token specifying the value of all fields."""
+        data = {
+            "token": "abcd",
+            "uses_allowed": 1,
+            "expiry_time": self.clock.time_msec() + 1000000,
+        }
+
+        channel = self.make_request(
+            "POST",
+            self.url + "/new",
+            data,
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["token"], "abcd")
+        self.assertEqual(channel.json_body["uses_allowed"], 1)
+        self.assertEqual(channel.json_body["expiry_time"], data["expiry_time"])
+        self.assertEqual(channel.json_body["pending"], 0)
+        self.assertEqual(channel.json_body["completed"], 0)
+
+    def test_create_with_null_value(self):
+        """Create a token specifying unlimited uses and no expiry."""
+        data = {
+            "uses_allowed": None,
+            "expiry_time": None,
+        }
+
+        channel = self.make_request(
+            "POST",
+            self.url + "/new",
+            data,
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(len(channel.json_body["token"]), 16)
+        self.assertIsNone(channel.json_body["uses_allowed"])
+        self.assertIsNone(channel.json_body["expiry_time"])
+        self.assertEqual(channel.json_body["pending"], 0)
+        self.assertEqual(channel.json_body["completed"], 0)
+
+    def test_create_token_too_long(self):
+        """Check token longer than 64 chars is invalid."""
+        data = {"token": "a" * 65}
+
+        channel = self.make_request(
+            "POST",
+            self.url + "/new",
+            data,
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+    def test_create_token_invalid_chars(self):
+        """Check you can't create token with invalid characters."""
+        data = {
+            "token": "abc/def",
+        }
+
+        channel = self.make_request(
+            "POST",
+            self.url + "/new",
+            data,
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+    def test_create_token_already_exists(self):
+        """Check you can't create token that already exists."""
+        data = {
+            "token": "abcd",
+        }
+
+        channel1 = self.make_request(
+            "POST",
+            self.url + "/new",
+            data,
+            access_token=self.admin_user_tok,
+        )
+        self.assertEqual(200, int(channel1.result["code"]), msg=channel1.result["body"])
+
+        channel2 = self.make_request(
+            "POST",
+            self.url + "/new",
+            data,
+            access_token=self.admin_user_tok,
+        )
+        self.assertEqual(400, int(channel2.result["code"]), msg=channel2.result["body"])
+        self.assertEqual(channel2.json_body["errcode"], Codes.INVALID_PARAM)
+
+    def test_create_unable_to_generate_token(self):
+        """Check right error is raised when server can't generate unique token."""
+        # Create all possible single character tokens
+        tokens = []
+        for c in string.ascii_letters + string.digits + "-_":
+            tokens.append(
+                {
+                    "token": c,
+                    "uses_allowed": None,
+                    "pending": 0,
+                    "completed": 0,
+                    "expiry_time": None,
+                }
+            )
+        self.get_success(
+            self.store.db_pool.simple_insert_many(
+                "registration_tokens",
+                tokens,
+                "create_all_registration_tokens",
+            )
+        )
+
+        # Check creating a single character token fails with a 500 status code
+        channel = self.make_request(
+            "POST",
+            self.url + "/new",
+            {"length": 1},
+            access_token=self.admin_user_tok,
+        )
+        self.assertEqual(500, int(channel.result["code"]), msg=channel.result["body"])
+
+    def test_create_uses_allowed(self):
+        """Check you can only create a token with good values for uses_allowed."""
+        # Should work with 0 (token is invalid from the start)
+        channel = self.make_request(
+            "POST",
+            self.url + "/new",
+            {"uses_allowed": 0},
+            access_token=self.admin_user_tok,
+        )
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["uses_allowed"], 0)
+
+        # Should fail with negative integer
+        channel = self.make_request(
+            "POST",
+            self.url + "/new",
+            {"uses_allowed": -5},
+            access_token=self.admin_user_tok,
+        )
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+        # Should fail with float
+        channel = self.make_request(
+            "POST",
+            self.url + "/new",
+            {"uses_allowed": 1.5},
+            access_token=self.admin_user_tok,
+        )
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+    def test_create_expiry_time(self):
+        """Check you can't create a token with an invalid expiry_time."""
+        # Should fail with a time in the past
+        channel = self.make_request(
+            "POST",
+            self.url + "/new",
+            {"expiry_time": self.clock.time_msec() - 10000},
+            access_token=self.admin_user_tok,
+        )
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+        # Should fail with float
+        channel = self.make_request(
+            "POST",
+            self.url + "/new",
+            {"expiry_time": self.clock.time_msec() + 1000000.5},
+            access_token=self.admin_user_tok,
+        )
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+    def test_create_length(self):
+        """Check you can only generate a token with a valid length."""
+        # Should work with 64
+        channel = self.make_request(
+            "POST",
+            self.url + "/new",
+            {"length": 64},
+            access_token=self.admin_user_tok,
+        )
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(len(channel.json_body["token"]), 64)
+
+        # Should fail with 0
+        channel = self.make_request(
+            "POST",
+            self.url + "/new",
+            {"length": 0},
+            access_token=self.admin_user_tok,
+        )
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+        # Should fail with a negative integer
+        channel = self.make_request(
+            "POST",
+            self.url + "/new",
+            {"length": -5},
+            access_token=self.admin_user_tok,
+        )
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+        # Should fail with a float
+        channel = self.make_request(
+            "POST",
+            self.url + "/new",
+            {"length": 8.5},
+            access_token=self.admin_user_tok,
+        )
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+        # Should fail with 65
+        channel = self.make_request(
+            "POST",
+            self.url + "/new",
+            {"length": 65},
+            access_token=self.admin_user_tok,
+        )
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+    # UPDATING
+
+    def test_update_no_auth(self):
+        """Try to update a token without authentication."""
+        channel = self.make_request(
+            "PUT",
+            self.url + "/1234",  # Token doesn't exist but that doesn't matter
+            {},
+        )
+        self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+    def test_update_requester_not_admin(self):
+        """Try to update a token while not an admin."""
+        channel = self.make_request(
+            "PUT",
+            self.url + "/1234",  # Token doesn't exist but that doesn't matter
+            {},
+            access_token=self.other_user_tok,
+        )
+        self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+    def test_update_non_existent(self):
+        """Try to update a token that doesn't exist."""
+        channel = self.make_request(
+            "PUT",
+            self.url + "/1234",
+            {"uses_allowed": 1},
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
+
+    def test_update_uses_allowed(self):
+        """Test updating just uses_allowed."""
+        # Create new token using default values
+        token = self._new_token()
+
+        # Should succeed with 1
+        channel = self.make_request(
+            "PUT",
+            self.url + "/" + token,
+            {"uses_allowed": 1},
+            access_token=self.admin_user_tok,
+        )
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["uses_allowed"], 1)
+        self.assertIsNone(channel.json_body["expiry_time"])
+
+        # Should succeed with 0 (makes token invalid)
+        channel = self.make_request(
+            "PUT",
+            self.url + "/" + token,
+            {"uses_allowed": 0},
+            access_token=self.admin_user_tok,
+        )
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["uses_allowed"], 0)
+        self.assertIsNone(channel.json_body["expiry_time"])
+
+        # Should succeed with null
+        channel = self.make_request(
+            "PUT",
+            self.url + "/" + token,
+            {"uses_allowed": None},
+            access_token=self.admin_user_tok,
+        )
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertIsNone(channel.json_body["uses_allowed"])
+        self.assertIsNone(channel.json_body["expiry_time"])
+
+        # Should fail with a float
+        channel = self.make_request(
+            "PUT",
+            self.url + "/" + token,
+            {"uses_allowed": 1.5},
+            access_token=self.admin_user_tok,
+        )
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+        # Should fail with a negative integer
+        channel = self.make_request(
+            "PUT",
+            self.url + "/" + token,
+            {"uses_allowed": -5},
+            access_token=self.admin_user_tok,
+        )
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+    def test_update_expiry_time(self):
+        """Test updating just expiry_time."""
+        # Create new token using default values
+        token = self._new_token()
+        new_expiry_time = self.clock.time_msec() + 1000000
+
+        # Should succeed with a time in the future
+        channel = self.make_request(
+            "PUT",
+            self.url + "/" + token,
+            {"expiry_time": new_expiry_time},
+            access_token=self.admin_user_tok,
+        )
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["expiry_time"], new_expiry_time)
+        self.assertIsNone(channel.json_body["uses_allowed"])
+
+        # Should succeed with null
+        channel = self.make_request(
+            "PUT",
+            self.url + "/" + token,
+            {"expiry_time": None},
+            access_token=self.admin_user_tok,
+        )
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertIsNone(channel.json_body["expiry_time"])
+        self.assertIsNone(channel.json_body["uses_allowed"])
+
+        # Should fail with a time in the past
+        past_time = self.clock.time_msec() - 10000
+        channel = self.make_request(
+            "PUT",
+            self.url + "/" + token,
+            {"expiry_time": past_time},
+            access_token=self.admin_user_tok,
+        )
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+        # Should fail a float
+        channel = self.make_request(
+            "PUT",
+            self.url + "/" + token,
+            {"expiry_time": new_expiry_time + 0.5},
+            access_token=self.admin_user_tok,
+        )
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+    def test_update_both(self):
+        """Test updating both uses_allowed and expiry_time."""
+        # Create new token using default values
+        token = self._new_token()
+        new_expiry_time = self.clock.time_msec() + 1000000
+
+        data = {
+            "uses_allowed": 1,
+            "expiry_time": new_expiry_time,
+        }
+
+        channel = self.make_request(
+            "PUT",
+            self.url + "/" + token,
+            data,
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["uses_allowed"], 1)
+        self.assertEqual(channel.json_body["expiry_time"], new_expiry_time)
+
+    def test_update_invalid_type(self):
+        """Test using invalid types doesn't work."""
+        # Create new token using default values
+        token = self._new_token()
+
+        data = {
+            "uses_allowed": False,
+            "expiry_time": "1626430124000",
+        }
+
+        channel = self.make_request(
+            "PUT",
+            self.url + "/" + token,
+            data,
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+    # DELETING
+
+    def test_delete_no_auth(self):
+        """Try to delete a token without authentication."""
+        channel = self.make_request(
+            "DELETE",
+            self.url + "/1234",  # Token doesn't exist but that doesn't matter
+            {},
+        )
+        self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+    def test_delete_requester_not_admin(self):
+        """Try to delete a token while not an admin."""
+        channel = self.make_request(
+            "DELETE",
+            self.url + "/1234",  # Token doesn't exist but that doesn't matter
+            {},
+            access_token=self.other_user_tok,
+        )
+        self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+    def test_delete_non_existent(self):
+        """Try to delete a token that doesn't exist."""
+        channel = self.make_request(
+            "DELETE",
+            self.url + "/1234",
+            {},
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
+
+    def test_delete(self):
+        """Test deleting a token."""
+        # Create new token using default values
+        token = self._new_token()
+
+        channel = self.make_request(
+            "DELETE",
+            self.url + "/" + token,
+            {},
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+    # GETTING ONE
+
+    def test_get_no_auth(self):
+        """Try to get a token without authentication."""
+        channel = self.make_request(
+            "GET",
+            self.url + "/1234",  # Token doesn't exist but that doesn't matter
+            {},
+        )
+        self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+    def test_get_requester_not_admin(self):
+        """Try to get a token while not an admin."""
+        channel = self.make_request(
+            "GET",
+            self.url + "/1234",  # Token doesn't exist but that doesn't matter
+            {},
+            access_token=self.other_user_tok,
+        )
+        self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+    def test_get_non_existent(self):
+        """Try to get a token that doesn't exist."""
+        channel = self.make_request(
+            "GET",
+            self.url + "/1234",
+            {},
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
+
+    def test_get(self):
+        """Test getting a token."""
+        # Create new token using default values
+        token = self._new_token()
+
+        channel = self.make_request(
+            "GET",
+            self.url + "/" + token,
+            {},
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["token"], token)
+        self.assertIsNone(channel.json_body["uses_allowed"])
+        self.assertIsNone(channel.json_body["expiry_time"])
+        self.assertEqual(channel.json_body["pending"], 0)
+        self.assertEqual(channel.json_body["completed"], 0)
+
+    # LISTING
+
+    def test_list_no_auth(self):
+        """Try to list tokens without authentication."""
+        channel = self.make_request("GET", self.url, {})
+        self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+    def test_list_requester_not_admin(self):
+        """Try to list tokens while not an admin."""
+        channel = self.make_request(
+            "GET",
+            self.url,
+            {},
+            access_token=self.other_user_tok,
+        )
+        self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+    def test_list_all(self):
+        """Test listing all tokens."""
+        # Create new token using default values
+        token = self._new_token()
+
+        channel = self.make_request(
+            "GET",
+            self.url,
+            {},
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(len(channel.json_body["registration_tokens"]), 1)
+        token_info = channel.json_body["registration_tokens"][0]
+        self.assertEqual(token_info["token"], token)
+        self.assertIsNone(token_info["uses_allowed"])
+        self.assertIsNone(token_info["expiry_time"])
+        self.assertEqual(token_info["pending"], 0)
+        self.assertEqual(token_info["completed"], 0)
+
+    def test_list_invalid_query_parameter(self):
+        """Test with `valid` query parameter not `true` or `false`."""
+        channel = self.make_request(
+            "GET",
+            self.url + "?valid=x",
+            {},
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+
+    def _test_list_query_parameter(self, valid: str):
+        """Helper used to test both valid=true and valid=false."""
+        # Create 2 valid and 2 invalid tokens.
+        now = self.hs.get_clock().time_msec()
+        # Create always valid token
+        valid1 = self._new_token()
+        # Create token that hasn't been used up
+        valid2 = self._new_token(uses_allowed=1)
+        # Create token that has expired
+        invalid1 = self._new_token(expiry_time=now - 10000)
+        # Create token that has been used up but hasn't expired
+        invalid2 = self._new_token(
+            uses_allowed=2,
+            pending=1,
+            completed=1,
+            expiry_time=now + 1000000,
+        )
+
+        if valid == "true":
+            tokens = [valid1, valid2]
+        else:
+            tokens = [invalid1, invalid2]
+
+        channel = self.make_request(
+            "GET",
+            self.url + "?valid=" + valid,
+            {},
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(len(channel.json_body["registration_tokens"]), 2)
+        token_info_1 = channel.json_body["registration_tokens"][0]
+        token_info_2 = channel.json_body["registration_tokens"][1]
+        self.assertIn(token_info_1["token"], tokens)
+        self.assertIn(token_info_2["token"], tokens)
+
+    def test_list_valid(self):
+        """Test listing just valid tokens."""
+        self._test_list_query_parameter(valid="true")
+
+    def test_list_invalid(self):
+        """Test listing just invalid tokens."""
+        self._test_list_query_parameter(valid="false")
diff --git a/tests/rest/client/test_register.py b/tests/rest/client/test_register.py
index fecda037a5..9f3ab2c985 100644
--- a/tests/rest/client/test_register.py
+++ b/tests/rest/client/test_register.py
@@ -24,6 +24,7 @@ from synapse.api.constants import APP_SERVICE_REGISTRATION_TYPE, LoginType
 from synapse.api.errors import Codes
 from synapse.appservice import ApplicationService
 from synapse.rest.client import account, account_validity, login, logout, register, sync
+from synapse.storage._base import db_to_json
 
 from tests import unittest
 from tests.unittest import override_config
@@ -204,6 +205,371 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
 
         self.assertEquals(channel.result["code"], b"200", channel.result)
 
+    @override_config({"registration_requires_token": True})
+    def test_POST_registration_requires_token(self):
+        username = "kermit"
+        device_id = "frogfone"
+        token = "abcd"
+        store = self.hs.get_datastore()
+        self.get_success(
+            store.db_pool.simple_insert(
+                "registration_tokens",
+                {
+                    "token": token,
+                    "uses_allowed": None,
+                    "pending": 0,
+                    "completed": 0,
+                    "expiry_time": None,
+                },
+            )
+        )
+        params = {
+            "username": username,
+            "password": "monkey",
+            "device_id": device_id,
+        }
+
+        # Request without auth to get flows and session
+        channel = self.make_request(b"POST", self.url, json.dumps(params))
+        self.assertEquals(channel.result["code"], b"401", channel.result)
+        flows = channel.json_body["flows"]
+        # Synapse adds a dummy stage to differentiate flows where otherwise one
+        # flow would be a subset of another flow.
+        self.assertCountEqual(
+            [[LoginType.REGISTRATION_TOKEN, LoginType.DUMMY]],
+            (f["stages"] for f in flows),
+        )
+        session = channel.json_body["session"]
+
+        # Do the registration token stage and check it has completed
+        params["auth"] = {
+            "type": LoginType.REGISTRATION_TOKEN,
+            "token": token,
+            "session": session,
+        }
+        request_data = json.dumps(params)
+        channel = self.make_request(b"POST", self.url, request_data)
+        self.assertEquals(channel.result["code"], b"401", channel.result)
+        completed = channel.json_body["completed"]
+        self.assertCountEqual([LoginType.REGISTRATION_TOKEN], completed)
+
+        # Do the m.login.dummy stage and check registration was successful
+        params["auth"] = {
+            "type": LoginType.DUMMY,
+            "session": session,
+        }
+        request_data = json.dumps(params)
+        channel = self.make_request(b"POST", self.url, request_data)
+        det_data = {
+            "user_id": f"@{username}:{self.hs.hostname}",
+            "home_server": self.hs.hostname,
+            "device_id": device_id,
+        }
+        self.assertEquals(channel.result["code"], b"200", channel.result)
+        self.assertDictContainsSubset(det_data, channel.json_body)
+
+        # Check the `completed` counter has been incremented and pending is 0
+        res = self.get_success(
+            store.db_pool.simple_select_one(
+                "registration_tokens",
+                keyvalues={"token": token},
+                retcols=["pending", "completed"],
+            )
+        )
+        self.assertEquals(res["completed"], 1)
+        self.assertEquals(res["pending"], 0)
+
+    @override_config({"registration_requires_token": True})
+    def test_POST_registration_token_invalid(self):
+        params = {
+            "username": "kermit",
+            "password": "monkey",
+        }
+        # Request without auth to get session
+        channel = self.make_request(b"POST", self.url, json.dumps(params))
+        session = channel.json_body["session"]
+
+        # Test with token param missing (invalid)
+        params["auth"] = {
+            "type": LoginType.REGISTRATION_TOKEN,
+            "session": session,
+        }
+        channel = self.make_request(b"POST", self.url, json.dumps(params))
+        self.assertEquals(channel.result["code"], b"401", channel.result)
+        self.assertEquals(channel.json_body["errcode"], Codes.MISSING_PARAM)
+        self.assertEquals(channel.json_body["completed"], [])
+
+        # Test with non-string (invalid)
+        params["auth"]["token"] = 1234
+        channel = self.make_request(b"POST", self.url, json.dumps(params))
+        self.assertEquals(channel.result["code"], b"401", channel.result)
+        self.assertEquals(channel.json_body["errcode"], Codes.INVALID_PARAM)
+        self.assertEquals(channel.json_body["completed"], [])
+
+        # Test with unknown token (invalid)
+        params["auth"]["token"] = "1234"
+        channel = self.make_request(b"POST", self.url, json.dumps(params))
+        self.assertEquals(channel.result["code"], b"401", channel.result)
+        self.assertEquals(channel.json_body["errcode"], Codes.UNAUTHORIZED)
+        self.assertEquals(channel.json_body["completed"], [])
+
+    @override_config({"registration_requires_token": True})
+    def test_POST_registration_token_limit_uses(self):
+        token = "abcd"
+        store = self.hs.get_datastore()
+        # Create token that can be used once
+        self.get_success(
+            store.db_pool.simple_insert(
+                "registration_tokens",
+                {
+                    "token": token,
+                    "uses_allowed": 1,
+                    "pending": 0,
+                    "completed": 0,
+                    "expiry_time": None,
+                },
+            )
+        )
+        params1 = {"username": "bert", "password": "monkey"}
+        params2 = {"username": "ernie", "password": "monkey"}
+        # Do 2 requests without auth to get two session IDs
+        channel1 = self.make_request(b"POST", self.url, json.dumps(params1))
+        session1 = channel1.json_body["session"]
+        channel2 = self.make_request(b"POST", self.url, json.dumps(params2))
+        session2 = channel2.json_body["session"]
+
+        # Use token with session1 and check `pending` is 1
+        params1["auth"] = {
+            "type": LoginType.REGISTRATION_TOKEN,
+            "token": token,
+            "session": session1,
+        }
+        self.make_request(b"POST", self.url, json.dumps(params1))
+        # Repeat request to make sure pending isn't increased again
+        self.make_request(b"POST", self.url, json.dumps(params1))
+        pending = self.get_success(
+            store.db_pool.simple_select_one_onecol(
+                "registration_tokens",
+                keyvalues={"token": token},
+                retcol="pending",
+            )
+        )
+        self.assertEquals(pending, 1)
+
+        # Check auth fails when using token with session2
+        params2["auth"] = {
+            "type": LoginType.REGISTRATION_TOKEN,
+            "token": token,
+            "session": session2,
+        }
+        channel = self.make_request(b"POST", self.url, json.dumps(params2))
+        self.assertEquals(channel.result["code"], b"401", channel.result)
+        self.assertEquals(channel.json_body["errcode"], Codes.UNAUTHORIZED)
+        self.assertEquals(channel.json_body["completed"], [])
+
+        # Complete registration with session1
+        params1["auth"]["type"] = LoginType.DUMMY
+        self.make_request(b"POST", self.url, json.dumps(params1))
+        # Check pending=0 and completed=1
+        res = self.get_success(
+            store.db_pool.simple_select_one(
+                "registration_tokens",
+                keyvalues={"token": token},
+                retcols=["pending", "completed"],
+            )
+        )
+        self.assertEquals(res["pending"], 0)
+        self.assertEquals(res["completed"], 1)
+
+        # Check auth still fails when using token with session2
+        channel = self.make_request(b"POST", self.url, json.dumps(params2))
+        self.assertEquals(channel.result["code"], b"401", channel.result)
+        self.assertEquals(channel.json_body["errcode"], Codes.UNAUTHORIZED)
+        self.assertEquals(channel.json_body["completed"], [])
+
+    @override_config({"registration_requires_token": True})
+    def test_POST_registration_token_expiry(self):
+        token = "abcd"
+        now = self.hs.get_clock().time_msec()
+        store = self.hs.get_datastore()
+        # Create token that expired yesterday
+        self.get_success(
+            store.db_pool.simple_insert(
+                "registration_tokens",
+                {
+                    "token": token,
+                    "uses_allowed": None,
+                    "pending": 0,
+                    "completed": 0,
+                    "expiry_time": now - 24 * 60 * 60 * 1000,
+                },
+            )
+        )
+        params = {"username": "kermit", "password": "monkey"}
+        # Request without auth to get session
+        channel = self.make_request(b"POST", self.url, json.dumps(params))
+        session = channel.json_body["session"]
+
+        # Check authentication fails with expired token
+        params["auth"] = {
+            "type": LoginType.REGISTRATION_TOKEN,
+            "token": token,
+            "session": session,
+        }
+        channel = self.make_request(b"POST", self.url, json.dumps(params))
+        self.assertEquals(channel.result["code"], b"401", channel.result)
+        self.assertEquals(channel.json_body["errcode"], Codes.UNAUTHORIZED)
+        self.assertEquals(channel.json_body["completed"], [])
+
+        # Update token so it expires tomorrow
+        self.get_success(
+            store.db_pool.simple_update_one(
+                "registration_tokens",
+                keyvalues={"token": token},
+                updatevalues={"expiry_time": now + 24 * 60 * 60 * 1000},
+            )
+        )
+
+        # Check authentication succeeds
+        channel = self.make_request(b"POST", self.url, json.dumps(params))
+        completed = channel.json_body["completed"]
+        self.assertCountEqual([LoginType.REGISTRATION_TOKEN], completed)
+
+    @override_config({"registration_requires_token": True})
+    def test_POST_registration_token_session_expiry(self):
+        """Test `pending` is decremented when an uncompleted session expires."""
+        token = "abcd"
+        store = self.hs.get_datastore()
+        self.get_success(
+            store.db_pool.simple_insert(
+                "registration_tokens",
+                {
+                    "token": token,
+                    "uses_allowed": None,
+                    "pending": 0,
+                    "completed": 0,
+                    "expiry_time": None,
+                },
+            )
+        )
+
+        # Do 2 requests without auth to get two session IDs
+        params1 = {"username": "bert", "password": "monkey"}
+        params2 = {"username": "ernie", "password": "monkey"}
+        channel1 = self.make_request(b"POST", self.url, json.dumps(params1))
+        session1 = channel1.json_body["session"]
+        channel2 = self.make_request(b"POST", self.url, json.dumps(params2))
+        session2 = channel2.json_body["session"]
+
+        # Use token with both sessions
+        params1["auth"] = {
+            "type": LoginType.REGISTRATION_TOKEN,
+            "token": token,
+            "session": session1,
+        }
+        self.make_request(b"POST", self.url, json.dumps(params1))
+
+        params2["auth"] = {
+            "type": LoginType.REGISTRATION_TOKEN,
+            "token": token,
+            "session": session2,
+        }
+        self.make_request(b"POST", self.url, json.dumps(params2))
+
+        # Complete registration with session1
+        params1["auth"]["type"] = LoginType.DUMMY
+        self.make_request(b"POST", self.url, json.dumps(params1))
+
+        # Check `result` of registration token stage for session1 is `True`
+        result1 = self.get_success(
+            store.db_pool.simple_select_one_onecol(
+                "ui_auth_sessions_credentials",
+                keyvalues={
+                    "session_id": session1,
+                    "stage_type": LoginType.REGISTRATION_TOKEN,
+                },
+                retcol="result",
+            )
+        )
+        self.assertTrue(db_to_json(result1))
+
+        # Check `result` for session2 is the token used
+        result2 = self.get_success(
+            store.db_pool.simple_select_one_onecol(
+                "ui_auth_sessions_credentials",
+                keyvalues={
+                    "session_id": session2,
+                    "stage_type": LoginType.REGISTRATION_TOKEN,
+                },
+                retcol="result",
+            )
+        )
+        self.assertEquals(db_to_json(result2), token)
+
+        # Delete both sessions (mimics expiry)
+        self.get_success(
+            store.delete_old_ui_auth_sessions(self.hs.get_clock().time_msec())
+        )
+
+        # Check pending is now 0
+        pending = self.get_success(
+            store.db_pool.simple_select_one_onecol(
+                "registration_tokens",
+                keyvalues={"token": token},
+                retcol="pending",
+            )
+        )
+        self.assertEquals(pending, 0)
+
+    @override_config({"registration_requires_token": True})
+    def test_POST_registration_token_session_expiry_deleted_token(self):
+        """Test session expiry doesn't break when the token is deleted.
+
+        1. Start but don't complete UIA with a registration token
+        2. Delete the token from the database
+        3. Expire the session
+        """
+        token = "abcd"
+        store = self.hs.get_datastore()
+        self.get_success(
+            store.db_pool.simple_insert(
+                "registration_tokens",
+                {
+                    "token": token,
+                    "uses_allowed": None,
+                    "pending": 0,
+                    "completed": 0,
+                    "expiry_time": None,
+                },
+            )
+        )
+
+        # Do request without auth to get a session ID
+        params = {"username": "kermit", "password": "monkey"}
+        channel = self.make_request(b"POST", self.url, json.dumps(params))
+        session = channel.json_body["session"]
+
+        # Use token
+        params["auth"] = {
+            "type": LoginType.REGISTRATION_TOKEN,
+            "token": token,
+            "session": session,
+        }
+        self.make_request(b"POST", self.url, json.dumps(params))
+
+        # Delete token
+        self.get_success(
+            store.db_pool.simple_delete_one(
+                "registration_tokens",
+                keyvalues={"token": token},
+            )
+        )
+
+        # Delete session (mimics expiry)
+        self.get_success(
+            store.delete_old_ui_auth_sessions(self.hs.get_clock().time_msec())
+        )
+
     def test_advertised_flows(self):
         channel = self.make_request(b"POST", self.url, b"{}")
         self.assertEquals(channel.result["code"], b"401", channel.result)
@@ -744,3 +1110,71 @@ class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase):
 
         self.assertGreaterEqual(res, now_ms + self.validity_period - self.max_delta)
         self.assertLessEqual(res, now_ms + self.validity_period)
+
+
+class RegistrationTokenValidityRestServletTestCase(unittest.HomeserverTestCase):
+    servlets = [register.register_servlets]
+    url = "/_matrix/client/unstable/org.matrix.msc3231/register/org.matrix.msc3231.login.registration_token/validity"
+
+    def default_config(self):
+        config = super().default_config()
+        config["registration_requires_token"] = True
+        return config
+
+    def test_GET_token_valid(self):
+        token = "abcd"
+        store = self.hs.get_datastore()
+        self.get_success(
+            store.db_pool.simple_insert(
+                "registration_tokens",
+                {
+                    "token": token,
+                    "uses_allowed": None,
+                    "pending": 0,
+                    "completed": 0,
+                    "expiry_time": None,
+                },
+            )
+        )
+
+        channel = self.make_request(
+            b"GET",
+            f"{self.url}?token={token}",
+        )
+        self.assertEquals(channel.result["code"], b"200", channel.result)
+        self.assertEquals(channel.json_body["valid"], True)
+
+    def test_GET_token_invalid(self):
+        token = "1234"
+        channel = self.make_request(
+            b"GET",
+            f"{self.url}?token={token}",
+        )
+        self.assertEquals(channel.result["code"], b"200", channel.result)
+        self.assertEquals(channel.json_body["valid"], False)
+
+    @override_config(
+        {"rc_registration_token_validity": {"per_second": 0.1, "burst_count": 5}}
+    )
+    def test_GET_ratelimiting(self):
+        token = "1234"
+
+        for i in range(0, 6):
+            channel = self.make_request(
+                b"GET",
+                f"{self.url}?token={token}",
+            )
+
+            if i == 5:
+                self.assertEquals(channel.result["code"], b"429", channel.result)
+                retry_after_ms = int(channel.json_body["retry_after_ms"])
+            else:
+                self.assertEquals(channel.result["code"], b"200", channel.result)
+
+        self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
+
+        channel = self.make_request(
+            b"GET",
+            f"{self.url}?token={token}",
+        )
+        self.assertEquals(channel.result["code"], b"200", channel.result)