summary refs log tree commit diff
path: root/tests/rest/client
diff options
context:
space:
mode:
Diffstat (limited to 'tests/rest/client')
-rw-r--r--tests/rest/client/test_account.py (renamed from tests/rest/client/v2_alpha/test_account.py)0
-rw-r--r--tests/rest/client/test_auth.py (renamed from tests/rest/client/v2_alpha/test_auth.py)2
-rw-r--r--tests/rest/client/test_capabilities.py (renamed from tests/rest/client/v2_alpha/test_capabilities.py)109
-rw-r--r--tests/rest/client/test_directory.py (renamed from tests/rest/client/v1/test_directory.py)0
-rw-r--r--tests/rest/client/test_events.py (renamed from tests/rest/client/v1/test_events.py)0
-rw-r--r--tests/rest/client/test_filter.py (renamed from tests/rest/client/v2_alpha/test_filter.py)0
-rw-r--r--tests/rest/client/test_keys.py91
-rw-r--r--tests/rest/client/test_login.py (renamed from tests/rest/client/v1/test_login.py)2
-rw-r--r--tests/rest/client/test_password_policy.py (renamed from tests/rest/client/v2_alpha/test_password_policy.py)0
-rw-r--r--tests/rest/client/test_presence.py (renamed from tests/rest/client/v1/test_presence.py)0
-rw-r--r--tests/rest/client/test_profile.py (renamed from tests/rest/client/v1/test_profile.py)0
-rw-r--r--tests/rest/client/test_push_rule_attrs.py (renamed from tests/rest/client/v1/test_push_rule_attrs.py)0
-rw-r--r--tests/rest/client/test_register.py (renamed from tests/rest/client/v2_alpha/test_register.py)434
-rw-r--r--tests/rest/client/test_relations.py (renamed from tests/rest/client/v2_alpha/test_relations.py)0
-rw-r--r--tests/rest/client/test_report_event.py (renamed from tests/rest/client/v2_alpha/test_report_event.py)0
-rw-r--r--tests/rest/client/test_rooms.py (renamed from tests/rest/client/v1/test_rooms.py)0
-rw-r--r--tests/rest/client/test_sendtodevice.py (renamed from tests/rest/client/v2_alpha/test_sendtodevice.py)0
-rw-r--r--tests/rest/client/test_shared_rooms.py (renamed from tests/rest/client/v2_alpha/test_shared_rooms.py)0
-rw-r--r--tests/rest/client/test_sync.py (renamed from tests/rest/client/v2_alpha/test_sync.py)0
-rw-r--r--tests/rest/client/test_typing.py (renamed from tests/rest/client/v1/test_typing.py)0
-rw-r--r--tests/rest/client/test_upgrade_room.py (renamed from tests/rest/client/v2_alpha/test_upgrade_room.py)0
-rw-r--r--tests/rest/client/utils.py (renamed from tests/rest/client/v1/utils.py)0
-rw-r--r--tests/rest/client/v1/__init__.py13
-rw-r--r--tests/rest/client/v2_alpha/__init__.py0
24 files changed, 613 insertions, 38 deletions
diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/test_account.py
index b946fca8b3..b946fca8b3 100644
--- a/tests/rest/client/v2_alpha/test_account.py
+++ b/tests/rest/client/test_account.py
diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/test_auth.py
index cf5cfb910c..e2fcbdc63a 100644
--- a/tests/rest/client/v2_alpha/test_auth.py
+++ b/tests/rest/client/test_auth.py
@@ -25,7 +25,7 @@ from synapse.types import JsonDict, UserID
 
 from tests import unittest
 from tests.handlers.test_oidc import HAS_OIDC
-from tests.rest.client.v1.utils import TEST_OIDC_CONFIG
+from tests.rest.client.utils import TEST_OIDC_CONFIG
 from tests.server import FakeChannel
 from tests.unittest import override_config, skip_unless
 
diff --git a/tests/rest/client/v2_alpha/test_capabilities.py b/tests/rest/client/test_capabilities.py
index 13b3c5f499..422361b62a 100644
--- a/tests/rest/client/v2_alpha/test_capabilities.py
+++ b/tests/rest/client/test_capabilities.py
@@ -30,19 +30,22 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
     def make_homeserver(self, reactor, clock):
         self.url = b"/_matrix/client/r0/capabilities"
         hs = self.setup_test_homeserver()
-        self.store = hs.get_datastore()
         self.config = hs.config
         self.auth_handler = hs.get_auth_handler()
         return hs
 
+    def prepare(self, reactor, clock, hs):
+        self.localpart = "user"
+        self.password = "pass"
+        self.user = self.register_user(self.localpart, self.password)
+
     def test_check_auth_required(self):
         channel = self.make_request("GET", self.url)
 
         self.assertEqual(channel.code, 401)
 
     def test_get_room_version_capabilities(self):
-        self.register_user("user", "pass")
-        access_token = self.login("user", "pass")
+        access_token = self.login(self.localpart, self.password)
 
         channel = self.make_request("GET", self.url, access_token=access_token)
         capabilities = channel.json_body["capabilities"]
@@ -57,10 +60,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
         )
 
     def test_get_change_password_capabilities_password_login(self):
-        localpart = "user"
-        password = "pass"
-        user = self.register_user(localpart, password)
-        access_token = self.login(user, password)
+        access_token = self.login(self.localpart, self.password)
 
         channel = self.make_request("GET", self.url, access_token=access_token)
         capabilities = channel.json_body["capabilities"]
@@ -70,12 +70,9 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
 
     @override_config({"password_config": {"localdb_enabled": False}})
     def test_get_change_password_capabilities_localdb_disabled(self):
-        localpart = "user"
-        password = "pass"
-        user = self.register_user(localpart, password)
         access_token = self.get_success(
             self.auth_handler.get_access_token_for_user_id(
-                user, device_id=None, valid_until_ms=None
+                self.user, device_id=None, valid_until_ms=None
             )
         )
 
@@ -87,12 +84,9 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
 
     @override_config({"password_config": {"enabled": False}})
     def test_get_change_password_capabilities_password_disabled(self):
-        localpart = "user"
-        password = "pass"
-        user = self.register_user(localpart, password)
         access_token = self.get_success(
             self.auth_handler.get_access_token_for_user_id(
-                user, device_id=None, valid_until_ms=None
+                self.user, device_id=None, valid_until_ms=None
             )
         )
 
@@ -102,14 +96,86 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
         self.assertEqual(channel.code, 200)
         self.assertFalse(capabilities["m.change_password"]["enabled"])
 
+    def test_get_change_users_attributes_capabilities_when_msc3283_disabled(self):
+        """Test that per default msc3283 is disabled server returns `m.change_password`."""
+        access_token = self.login(self.localpart, self.password)
+
+        channel = self.make_request("GET", self.url, access_token=access_token)
+        capabilities = channel.json_body["capabilities"]
+
+        self.assertEqual(channel.code, 200)
+        self.assertTrue(capabilities["m.change_password"]["enabled"])
+        self.assertNotIn("org.matrix.msc3283.set_displayname", capabilities)
+        self.assertNotIn("org.matrix.msc3283.set_avatar_url", capabilities)
+        self.assertNotIn("org.matrix.msc3283.3pid_changes", capabilities)
+
+    @override_config({"experimental_features": {"msc3283_enabled": True}})
+    def test_get_change_users_attributes_capabilities_when_msc3283_enabled(self):
+        """Test if msc3283 is enabled server returns capabilities."""
+        access_token = self.login(self.localpart, self.password)
+
+        channel = self.make_request("GET", self.url, access_token=access_token)
+        capabilities = channel.json_body["capabilities"]
+
+        self.assertEqual(channel.code, 200)
+        self.assertTrue(capabilities["m.change_password"]["enabled"])
+        self.assertTrue(capabilities["org.matrix.msc3283.set_displayname"]["enabled"])
+        self.assertTrue(capabilities["org.matrix.msc3283.set_avatar_url"]["enabled"])
+        self.assertTrue(capabilities["org.matrix.msc3283.3pid_changes"]["enabled"])
+
+    @override_config(
+        {
+            "enable_set_displayname": False,
+            "experimental_features": {"msc3283_enabled": True},
+        }
+    )
+    def test_get_set_displayname_capabilities_displayname_disabled(self):
+        """Test if set displayname is disabled that the server responds it."""
+        access_token = self.login(self.localpart, self.password)
+
+        channel = self.make_request("GET", self.url, access_token=access_token)
+        capabilities = channel.json_body["capabilities"]
+
+        self.assertEqual(channel.code, 200)
+        self.assertFalse(capabilities["org.matrix.msc3283.set_displayname"]["enabled"])
+
+    @override_config(
+        {
+            "enable_set_avatar_url": False,
+            "experimental_features": {"msc3283_enabled": True},
+        }
+    )
+    def test_get_set_avatar_url_capabilities_avatar_url_disabled(self):
+        """Test if set avatar_url is disabled that the server responds it."""
+        access_token = self.login(self.localpart, self.password)
+
+        channel = self.make_request("GET", self.url, access_token=access_token)
+        capabilities = channel.json_body["capabilities"]
+
+        self.assertEqual(channel.code, 200)
+        self.assertFalse(capabilities["org.matrix.msc3283.set_avatar_url"]["enabled"])
+
+    @override_config(
+        {
+            "enable_3pid_changes": False,
+            "experimental_features": {"msc3283_enabled": True},
+        }
+    )
+    def test_change_3pid_capabilities_3pid_disabled(self):
+        """Test if change 3pid is disabled that the server responds it."""
+        access_token = self.login(self.localpart, self.password)
+
+        channel = self.make_request("GET", self.url, access_token=access_token)
+        capabilities = channel.json_body["capabilities"]
+
+        self.assertEqual(channel.code, 200)
+        self.assertFalse(capabilities["org.matrix.msc3283.3pid_changes"]["enabled"])
+
     @override_config({"experimental_features": {"msc3244_enabled": False}})
     def test_get_does_not_include_msc3244_fields_when_disabled(self):
-        localpart = "user"
-        password = "pass"
-        user = self.register_user(localpart, password)
         access_token = self.get_success(
             self.auth_handler.get_access_token_for_user_id(
-                user, device_id=None, valid_until_ms=None
+                self.user, device_id=None, valid_until_ms=None
             )
         )
 
@@ -122,12 +188,9 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
         )
 
     def test_get_does_include_msc3244_fields_when_enabled(self):
-        localpart = "user"
-        password = "pass"
-        user = self.register_user(localpart, password)
         access_token = self.get_success(
             self.auth_handler.get_access_token_for_user_id(
-                user, device_id=None, valid_until_ms=None
+                self.user, device_id=None, valid_until_ms=None
             )
         )
 
diff --git a/tests/rest/client/v1/test_directory.py b/tests/rest/client/test_directory.py
index d2181ea907..d2181ea907 100644
--- a/tests/rest/client/v1/test_directory.py
+++ b/tests/rest/client/test_directory.py
diff --git a/tests/rest/client/v1/test_events.py b/tests/rest/client/test_events.py
index a90294003e..a90294003e 100644
--- a/tests/rest/client/v1/test_events.py
+++ b/tests/rest/client/test_events.py
diff --git a/tests/rest/client/v2_alpha/test_filter.py b/tests/rest/client/test_filter.py
index 475c6bed3d..475c6bed3d 100644
--- a/tests/rest/client/v2_alpha/test_filter.py
+++ b/tests/rest/client/test_filter.py
diff --git a/tests/rest/client/test_keys.py b/tests/rest/client/test_keys.py
new file mode 100644
index 0000000000..d7fa635eae
--- /dev/null
+++ b/tests/rest/client/test_keys.py
@@ -0,0 +1,91 @@
+#  Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+#  Licensed under the Apache License, Version 2.0 (the "License");
+#  you may not use this file except in compliance with the License.
+#  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+#  limitations under the License
+
+from http import HTTPStatus
+
+from synapse.api.errors import Codes
+from synapse.rest import admin
+from synapse.rest.client import keys, login
+
+from tests import unittest
+
+
+class KeyQueryTestCase(unittest.HomeserverTestCase):
+    servlets = [
+        keys.register_servlets,
+        admin.register_servlets_for_client_rest_resource,
+        login.register_servlets,
+    ]
+
+    def test_rejects_device_id_ice_key_outside_of_list(self):
+        self.register_user("alice", "wonderland")
+        alice_token = self.login("alice", "wonderland")
+        bob = self.register_user("bob", "uncle")
+        channel = self.make_request(
+            "POST",
+            "/_matrix/client/r0/keys/query",
+            {
+                "device_keys": {
+                    bob: "device_id1",
+                },
+            },
+            alice_token,
+        )
+        self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
+        self.assertEqual(
+            channel.json_body["errcode"],
+            Codes.BAD_JSON,
+            channel.result,
+        )
+
+    def test_rejects_device_key_given_as_map_to_bool(self):
+        self.register_user("alice", "wonderland")
+        alice_token = self.login("alice", "wonderland")
+        bob = self.register_user("bob", "uncle")
+        channel = self.make_request(
+            "POST",
+            "/_matrix/client/r0/keys/query",
+            {
+                "device_keys": {
+                    bob: {
+                        "device_id1": True,
+                    },
+                },
+            },
+            alice_token,
+        )
+
+        self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
+        self.assertEqual(
+            channel.json_body["errcode"],
+            Codes.BAD_JSON,
+            channel.result,
+        )
+
+    def test_requires_device_key(self):
+        """`device_keys` is required. We should complain if it's missing."""
+        self.register_user("alice", "wonderland")
+        alice_token = self.login("alice", "wonderland")
+        channel = self.make_request(
+            "POST",
+            "/_matrix/client/r0/keys/query",
+            {},
+            alice_token,
+        )
+        self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
+        self.assertEqual(
+            channel.json_body["errcode"],
+            Codes.BAD_JSON,
+            channel.result,
+        )
diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/test_login.py
index eba3552b19..5b2243fe52 100644
--- a/tests/rest/client/v1/test_login.py
+++ b/tests/rest/client/test_login.py
@@ -32,7 +32,7 @@ from synapse.types import create_requester
 from tests import unittest
 from tests.handlers.test_oidc import HAS_OIDC
 from tests.handlers.test_saml import has_saml2
-from tests.rest.client.v1.utils import TEST_OIDC_AUTH_ENDPOINT, TEST_OIDC_CONFIG
+from tests.rest.client.utils import TEST_OIDC_AUTH_ENDPOINT, TEST_OIDC_CONFIG
 from tests.test_utils.html_parsers import TestHtmlParser
 from tests.unittest import HomeserverTestCase, override_config, skip_unless
 
diff --git a/tests/rest/client/v2_alpha/test_password_policy.py b/tests/rest/client/test_password_policy.py
index 3cf5871899..3cf5871899 100644
--- a/tests/rest/client/v2_alpha/test_password_policy.py
+++ b/tests/rest/client/test_password_policy.py
diff --git a/tests/rest/client/v1/test_presence.py b/tests/rest/client/test_presence.py
index 1d152352d1..1d152352d1 100644
--- a/tests/rest/client/v1/test_presence.py
+++ b/tests/rest/client/test_presence.py
diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/test_profile.py
index 2860579c2e..2860579c2e 100644
--- a/tests/rest/client/v1/test_profile.py
+++ b/tests/rest/client/test_profile.py
diff --git a/tests/rest/client/v1/test_push_rule_attrs.py b/tests/rest/client/test_push_rule_attrs.py
index d0ce91ccd9..d0ce91ccd9 100644
--- a/tests/rest/client/v1/test_push_rule_attrs.py
+++ b/tests/rest/client/test_push_rule_attrs.py
diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/test_register.py
index fecda037a5..9f3ab2c985 100644
--- a/tests/rest/client/v2_alpha/test_register.py
+++ b/tests/rest/client/test_register.py
@@ -24,6 +24,7 @@ from synapse.api.constants import APP_SERVICE_REGISTRATION_TYPE, LoginType
 from synapse.api.errors import Codes
 from synapse.appservice import ApplicationService
 from synapse.rest.client import account, account_validity, login, logout, register, sync
+from synapse.storage._base import db_to_json
 
 from tests import unittest
 from tests.unittest import override_config
@@ -204,6 +205,371 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
 
         self.assertEquals(channel.result["code"], b"200", channel.result)
 
+    @override_config({"registration_requires_token": True})
+    def test_POST_registration_requires_token(self):
+        username = "kermit"
+        device_id = "frogfone"
+        token = "abcd"
+        store = self.hs.get_datastore()
+        self.get_success(
+            store.db_pool.simple_insert(
+                "registration_tokens",
+                {
+                    "token": token,
+                    "uses_allowed": None,
+                    "pending": 0,
+                    "completed": 0,
+                    "expiry_time": None,
+                },
+            )
+        )
+        params = {
+            "username": username,
+            "password": "monkey",
+            "device_id": device_id,
+        }
+
+        # Request without auth to get flows and session
+        channel = self.make_request(b"POST", self.url, json.dumps(params))
+        self.assertEquals(channel.result["code"], b"401", channel.result)
+        flows = channel.json_body["flows"]
+        # Synapse adds a dummy stage to differentiate flows where otherwise one
+        # flow would be a subset of another flow.
+        self.assertCountEqual(
+            [[LoginType.REGISTRATION_TOKEN, LoginType.DUMMY]],
+            (f["stages"] for f in flows),
+        )
+        session = channel.json_body["session"]
+
+        # Do the registration token stage and check it has completed
+        params["auth"] = {
+            "type": LoginType.REGISTRATION_TOKEN,
+            "token": token,
+            "session": session,
+        }
+        request_data = json.dumps(params)
+        channel = self.make_request(b"POST", self.url, request_data)
+        self.assertEquals(channel.result["code"], b"401", channel.result)
+        completed = channel.json_body["completed"]
+        self.assertCountEqual([LoginType.REGISTRATION_TOKEN], completed)
+
+        # Do the m.login.dummy stage and check registration was successful
+        params["auth"] = {
+            "type": LoginType.DUMMY,
+            "session": session,
+        }
+        request_data = json.dumps(params)
+        channel = self.make_request(b"POST", self.url, request_data)
+        det_data = {
+            "user_id": f"@{username}:{self.hs.hostname}",
+            "home_server": self.hs.hostname,
+            "device_id": device_id,
+        }
+        self.assertEquals(channel.result["code"], b"200", channel.result)
+        self.assertDictContainsSubset(det_data, channel.json_body)
+
+        # Check the `completed` counter has been incremented and pending is 0
+        res = self.get_success(
+            store.db_pool.simple_select_one(
+                "registration_tokens",
+                keyvalues={"token": token},
+                retcols=["pending", "completed"],
+            )
+        )
+        self.assertEquals(res["completed"], 1)
+        self.assertEquals(res["pending"], 0)
+
+    @override_config({"registration_requires_token": True})
+    def test_POST_registration_token_invalid(self):
+        params = {
+            "username": "kermit",
+            "password": "monkey",
+        }
+        # Request without auth to get session
+        channel = self.make_request(b"POST", self.url, json.dumps(params))
+        session = channel.json_body["session"]
+
+        # Test with token param missing (invalid)
+        params["auth"] = {
+            "type": LoginType.REGISTRATION_TOKEN,
+            "session": session,
+        }
+        channel = self.make_request(b"POST", self.url, json.dumps(params))
+        self.assertEquals(channel.result["code"], b"401", channel.result)
+        self.assertEquals(channel.json_body["errcode"], Codes.MISSING_PARAM)
+        self.assertEquals(channel.json_body["completed"], [])
+
+        # Test with non-string (invalid)
+        params["auth"]["token"] = 1234
+        channel = self.make_request(b"POST", self.url, json.dumps(params))
+        self.assertEquals(channel.result["code"], b"401", channel.result)
+        self.assertEquals(channel.json_body["errcode"], Codes.INVALID_PARAM)
+        self.assertEquals(channel.json_body["completed"], [])
+
+        # Test with unknown token (invalid)
+        params["auth"]["token"] = "1234"
+        channel = self.make_request(b"POST", self.url, json.dumps(params))
+        self.assertEquals(channel.result["code"], b"401", channel.result)
+        self.assertEquals(channel.json_body["errcode"], Codes.UNAUTHORIZED)
+        self.assertEquals(channel.json_body["completed"], [])
+
+    @override_config({"registration_requires_token": True})
+    def test_POST_registration_token_limit_uses(self):
+        token = "abcd"
+        store = self.hs.get_datastore()
+        # Create token that can be used once
+        self.get_success(
+            store.db_pool.simple_insert(
+                "registration_tokens",
+                {
+                    "token": token,
+                    "uses_allowed": 1,
+                    "pending": 0,
+                    "completed": 0,
+                    "expiry_time": None,
+                },
+            )
+        )
+        params1 = {"username": "bert", "password": "monkey"}
+        params2 = {"username": "ernie", "password": "monkey"}
+        # Do 2 requests without auth to get two session IDs
+        channel1 = self.make_request(b"POST", self.url, json.dumps(params1))
+        session1 = channel1.json_body["session"]
+        channel2 = self.make_request(b"POST", self.url, json.dumps(params2))
+        session2 = channel2.json_body["session"]
+
+        # Use token with session1 and check `pending` is 1
+        params1["auth"] = {
+            "type": LoginType.REGISTRATION_TOKEN,
+            "token": token,
+            "session": session1,
+        }
+        self.make_request(b"POST", self.url, json.dumps(params1))
+        # Repeat request to make sure pending isn't increased again
+        self.make_request(b"POST", self.url, json.dumps(params1))
+        pending = self.get_success(
+            store.db_pool.simple_select_one_onecol(
+                "registration_tokens",
+                keyvalues={"token": token},
+                retcol="pending",
+            )
+        )
+        self.assertEquals(pending, 1)
+
+        # Check auth fails when using token with session2
+        params2["auth"] = {
+            "type": LoginType.REGISTRATION_TOKEN,
+            "token": token,
+            "session": session2,
+        }
+        channel = self.make_request(b"POST", self.url, json.dumps(params2))
+        self.assertEquals(channel.result["code"], b"401", channel.result)
+        self.assertEquals(channel.json_body["errcode"], Codes.UNAUTHORIZED)
+        self.assertEquals(channel.json_body["completed"], [])
+
+        # Complete registration with session1
+        params1["auth"]["type"] = LoginType.DUMMY
+        self.make_request(b"POST", self.url, json.dumps(params1))
+        # Check pending=0 and completed=1
+        res = self.get_success(
+            store.db_pool.simple_select_one(
+                "registration_tokens",
+                keyvalues={"token": token},
+                retcols=["pending", "completed"],
+            )
+        )
+        self.assertEquals(res["pending"], 0)
+        self.assertEquals(res["completed"], 1)
+
+        # Check auth still fails when using token with session2
+        channel = self.make_request(b"POST", self.url, json.dumps(params2))
+        self.assertEquals(channel.result["code"], b"401", channel.result)
+        self.assertEquals(channel.json_body["errcode"], Codes.UNAUTHORIZED)
+        self.assertEquals(channel.json_body["completed"], [])
+
+    @override_config({"registration_requires_token": True})
+    def test_POST_registration_token_expiry(self):
+        token = "abcd"
+        now = self.hs.get_clock().time_msec()
+        store = self.hs.get_datastore()
+        # Create token that expired yesterday
+        self.get_success(
+            store.db_pool.simple_insert(
+                "registration_tokens",
+                {
+                    "token": token,
+                    "uses_allowed": None,
+                    "pending": 0,
+                    "completed": 0,
+                    "expiry_time": now - 24 * 60 * 60 * 1000,
+                },
+            )
+        )
+        params = {"username": "kermit", "password": "monkey"}
+        # Request without auth to get session
+        channel = self.make_request(b"POST", self.url, json.dumps(params))
+        session = channel.json_body["session"]
+
+        # Check authentication fails with expired token
+        params["auth"] = {
+            "type": LoginType.REGISTRATION_TOKEN,
+            "token": token,
+            "session": session,
+        }
+        channel = self.make_request(b"POST", self.url, json.dumps(params))
+        self.assertEquals(channel.result["code"], b"401", channel.result)
+        self.assertEquals(channel.json_body["errcode"], Codes.UNAUTHORIZED)
+        self.assertEquals(channel.json_body["completed"], [])
+
+        # Update token so it expires tomorrow
+        self.get_success(
+            store.db_pool.simple_update_one(
+                "registration_tokens",
+                keyvalues={"token": token},
+                updatevalues={"expiry_time": now + 24 * 60 * 60 * 1000},
+            )
+        )
+
+        # Check authentication succeeds
+        channel = self.make_request(b"POST", self.url, json.dumps(params))
+        completed = channel.json_body["completed"]
+        self.assertCountEqual([LoginType.REGISTRATION_TOKEN], completed)
+
+    @override_config({"registration_requires_token": True})
+    def test_POST_registration_token_session_expiry(self):
+        """Test `pending` is decremented when an uncompleted session expires."""
+        token = "abcd"
+        store = self.hs.get_datastore()
+        self.get_success(
+            store.db_pool.simple_insert(
+                "registration_tokens",
+                {
+                    "token": token,
+                    "uses_allowed": None,
+                    "pending": 0,
+                    "completed": 0,
+                    "expiry_time": None,
+                },
+            )
+        )
+
+        # Do 2 requests without auth to get two session IDs
+        params1 = {"username": "bert", "password": "monkey"}
+        params2 = {"username": "ernie", "password": "monkey"}
+        channel1 = self.make_request(b"POST", self.url, json.dumps(params1))
+        session1 = channel1.json_body["session"]
+        channel2 = self.make_request(b"POST", self.url, json.dumps(params2))
+        session2 = channel2.json_body["session"]
+
+        # Use token with both sessions
+        params1["auth"] = {
+            "type": LoginType.REGISTRATION_TOKEN,
+            "token": token,
+            "session": session1,
+        }
+        self.make_request(b"POST", self.url, json.dumps(params1))
+
+        params2["auth"] = {
+            "type": LoginType.REGISTRATION_TOKEN,
+            "token": token,
+            "session": session2,
+        }
+        self.make_request(b"POST", self.url, json.dumps(params2))
+
+        # Complete registration with session1
+        params1["auth"]["type"] = LoginType.DUMMY
+        self.make_request(b"POST", self.url, json.dumps(params1))
+
+        # Check `result` of registration token stage for session1 is `True`
+        result1 = self.get_success(
+            store.db_pool.simple_select_one_onecol(
+                "ui_auth_sessions_credentials",
+                keyvalues={
+                    "session_id": session1,
+                    "stage_type": LoginType.REGISTRATION_TOKEN,
+                },
+                retcol="result",
+            )
+        )
+        self.assertTrue(db_to_json(result1))
+
+        # Check `result` for session2 is the token used
+        result2 = self.get_success(
+            store.db_pool.simple_select_one_onecol(
+                "ui_auth_sessions_credentials",
+                keyvalues={
+                    "session_id": session2,
+                    "stage_type": LoginType.REGISTRATION_TOKEN,
+                },
+                retcol="result",
+            )
+        )
+        self.assertEquals(db_to_json(result2), token)
+
+        # Delete both sessions (mimics expiry)
+        self.get_success(
+            store.delete_old_ui_auth_sessions(self.hs.get_clock().time_msec())
+        )
+
+        # Check pending is now 0
+        pending = self.get_success(
+            store.db_pool.simple_select_one_onecol(
+                "registration_tokens",
+                keyvalues={"token": token},
+                retcol="pending",
+            )
+        )
+        self.assertEquals(pending, 0)
+
+    @override_config({"registration_requires_token": True})
+    def test_POST_registration_token_session_expiry_deleted_token(self):
+        """Test session expiry doesn't break when the token is deleted.
+
+        1. Start but don't complete UIA with a registration token
+        2. Delete the token from the database
+        3. Expire the session
+        """
+        token = "abcd"
+        store = self.hs.get_datastore()
+        self.get_success(
+            store.db_pool.simple_insert(
+                "registration_tokens",
+                {
+                    "token": token,
+                    "uses_allowed": None,
+                    "pending": 0,
+                    "completed": 0,
+                    "expiry_time": None,
+                },
+            )
+        )
+
+        # Do request without auth to get a session ID
+        params = {"username": "kermit", "password": "monkey"}
+        channel = self.make_request(b"POST", self.url, json.dumps(params))
+        session = channel.json_body["session"]
+
+        # Use token
+        params["auth"] = {
+            "type": LoginType.REGISTRATION_TOKEN,
+            "token": token,
+            "session": session,
+        }
+        self.make_request(b"POST", self.url, json.dumps(params))
+
+        # Delete token
+        self.get_success(
+            store.db_pool.simple_delete_one(
+                "registration_tokens",
+                keyvalues={"token": token},
+            )
+        )
+
+        # Delete session (mimics expiry)
+        self.get_success(
+            store.delete_old_ui_auth_sessions(self.hs.get_clock().time_msec())
+        )
+
     def test_advertised_flows(self):
         channel = self.make_request(b"POST", self.url, b"{}")
         self.assertEquals(channel.result["code"], b"401", channel.result)
@@ -744,3 +1110,71 @@ class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase):
 
         self.assertGreaterEqual(res, now_ms + self.validity_period - self.max_delta)
         self.assertLessEqual(res, now_ms + self.validity_period)
+
+
+class RegistrationTokenValidityRestServletTestCase(unittest.HomeserverTestCase):
+    servlets = [register.register_servlets]
+    url = "/_matrix/client/unstable/org.matrix.msc3231/register/org.matrix.msc3231.login.registration_token/validity"
+
+    def default_config(self):
+        config = super().default_config()
+        config["registration_requires_token"] = True
+        return config
+
+    def test_GET_token_valid(self):
+        token = "abcd"
+        store = self.hs.get_datastore()
+        self.get_success(
+            store.db_pool.simple_insert(
+                "registration_tokens",
+                {
+                    "token": token,
+                    "uses_allowed": None,
+                    "pending": 0,
+                    "completed": 0,
+                    "expiry_time": None,
+                },
+            )
+        )
+
+        channel = self.make_request(
+            b"GET",
+            f"{self.url}?token={token}",
+        )
+        self.assertEquals(channel.result["code"], b"200", channel.result)
+        self.assertEquals(channel.json_body["valid"], True)
+
+    def test_GET_token_invalid(self):
+        token = "1234"
+        channel = self.make_request(
+            b"GET",
+            f"{self.url}?token={token}",
+        )
+        self.assertEquals(channel.result["code"], b"200", channel.result)
+        self.assertEquals(channel.json_body["valid"], False)
+
+    @override_config(
+        {"rc_registration_token_validity": {"per_second": 0.1, "burst_count": 5}}
+    )
+    def test_GET_ratelimiting(self):
+        token = "1234"
+
+        for i in range(0, 6):
+            channel = self.make_request(
+                b"GET",
+                f"{self.url}?token={token}",
+            )
+
+            if i == 5:
+                self.assertEquals(channel.result["code"], b"429", channel.result)
+                retry_after_ms = int(channel.json_body["retry_after_ms"])
+            else:
+                self.assertEquals(channel.result["code"], b"200", channel.result)
+
+        self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
+
+        channel = self.make_request(
+            b"GET",
+            f"{self.url}?token={token}",
+        )
+        self.assertEquals(channel.result["code"], b"200", channel.result)
diff --git a/tests/rest/client/v2_alpha/test_relations.py b/tests/rest/client/test_relations.py
index 02b5e9a8d0..02b5e9a8d0 100644
--- a/tests/rest/client/v2_alpha/test_relations.py
+++ b/tests/rest/client/test_relations.py
diff --git a/tests/rest/client/v2_alpha/test_report_event.py b/tests/rest/client/test_report_event.py
index ee6b0b9ebf..ee6b0b9ebf 100644
--- a/tests/rest/client/v2_alpha/test_report_event.py
+++ b/tests/rest/client/test_report_event.py
diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/test_rooms.py
index 0c9cbb9aff..0c9cbb9aff 100644
--- a/tests/rest/client/v1/test_rooms.py
+++ b/tests/rest/client/test_rooms.py
diff --git a/tests/rest/client/v2_alpha/test_sendtodevice.py b/tests/rest/client/test_sendtodevice.py
index 6db7062a8e..6db7062a8e 100644
--- a/tests/rest/client/v2_alpha/test_sendtodevice.py
+++ b/tests/rest/client/test_sendtodevice.py
diff --git a/tests/rest/client/v2_alpha/test_shared_rooms.py b/tests/rest/client/test_shared_rooms.py
index 283eccd53f..283eccd53f 100644
--- a/tests/rest/client/v2_alpha/test_shared_rooms.py
+++ b/tests/rest/client/test_shared_rooms.py
diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/test_sync.py
index 95be369d4b..95be369d4b 100644
--- a/tests/rest/client/v2_alpha/test_sync.py
+++ b/tests/rest/client/test_sync.py
diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/test_typing.py
index b54b004733..b54b004733 100644
--- a/tests/rest/client/v1/test_typing.py
+++ b/tests/rest/client/test_typing.py
diff --git a/tests/rest/client/v2_alpha/test_upgrade_room.py b/tests/rest/client/test_upgrade_room.py
index 72f976d8e2..72f976d8e2 100644
--- a/tests/rest/client/v2_alpha/test_upgrade_room.py
+++ b/tests/rest/client/test_upgrade_room.py
diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/utils.py
index 954ad1a1fd..954ad1a1fd 100644
--- a/tests/rest/client/v1/utils.py
+++ b/tests/rest/client/utils.py
diff --git a/tests/rest/client/v1/__init__.py b/tests/rest/client/v1/__init__.py
deleted file mode 100644
index 5e83dba2ed..0000000000
--- a/tests/rest/client/v1/__init__.py
+++ /dev/null
@@ -1,13 +0,0 @@
-# Copyright 2014-2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
diff --git a/tests/rest/client/v2_alpha/__init__.py b/tests/rest/client/v2_alpha/__init__.py
deleted file mode 100644
index e69de29bb2..0000000000
--- a/tests/rest/client/v2_alpha/__init__.py
+++ /dev/null