summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/events/test_presence_router.py109
-rw-r--r--tests/handlers/test_sync.py137
-rw-r--r--tests/replication/slave/storage/test_events.py10
-rw-r--r--tests/rest/admin/test_device.py99
-rw-r--r--tests/rest/admin/test_registration_tokens.py710
-rw-r--r--tests/rest/admin/test_room.py162
-rw-r--r--tests/rest/admin/test_user.py62
-rw-r--r--tests/rest/client/test_account.py (renamed from tests/rest/client/v2_alpha/test_account.py)0
-rw-r--r--tests/rest/client/test_auth.py (renamed from tests/rest/client/v2_alpha/test_auth.py)2
-rw-r--r--tests/rest/client/test_capabilities.py (renamed from tests/rest/client/v2_alpha/test_capabilities.py)109
-rw-r--r--tests/rest/client/test_directory.py (renamed from tests/rest/client/v1/test_directory.py)0
-rw-r--r--tests/rest/client/test_events.py (renamed from tests/rest/client/v1/test_events.py)0
-rw-r--r--tests/rest/client/test_filter.py (renamed from tests/rest/client/v2_alpha/test_filter.py)0
-rw-r--r--tests/rest/client/test_keys.py91
-rw-r--r--tests/rest/client/test_login.py (renamed from tests/rest/client/v1/test_login.py)2
-rw-r--r--tests/rest/client/test_password_policy.py (renamed from tests/rest/client/v2_alpha/test_password_policy.py)0
-rw-r--r--tests/rest/client/test_presence.py (renamed from tests/rest/client/v1/test_presence.py)0
-rw-r--r--tests/rest/client/test_profile.py (renamed from tests/rest/client/v1/test_profile.py)0
-rw-r--r--tests/rest/client/test_push_rule_attrs.py (renamed from tests/rest/client/v1/test_push_rule_attrs.py)0
-rw-r--r--tests/rest/client/test_register.py (renamed from tests/rest/client/v2_alpha/test_register.py)434
-rw-r--r--tests/rest/client/test_relations.py (renamed from tests/rest/client/v2_alpha/test_relations.py)0
-rw-r--r--tests/rest/client/test_report_event.py (renamed from tests/rest/client/v2_alpha/test_report_event.py)0
-rw-r--r--tests/rest/client/test_rooms.py (renamed from tests/rest/client/v1/test_rooms.py)0
-rw-r--r--tests/rest/client/test_sendtodevice.py (renamed from tests/rest/client/v2_alpha/test_sendtodevice.py)0
-rw-r--r--tests/rest/client/test_shared_rooms.py (renamed from tests/rest/client/v2_alpha/test_shared_rooms.py)0
-rw-r--r--tests/rest/client/test_sync.py (renamed from tests/rest/client/v2_alpha/test_sync.py)0
-rw-r--r--tests/rest/client/test_typing.py (renamed from tests/rest/client/v1/test_typing.py)0
-rw-r--r--tests/rest/client/test_upgrade_room.py (renamed from tests/rest/client/v2_alpha/test_upgrade_room.py)0
-rw-r--r--tests/rest/client/utils.py (renamed from tests/rest/client/v1/utils.py)0
-rw-r--r--tests/rest/client/v1/__init__.py13
-rw-r--r--tests/rest/client/v2_alpha/__init__.py0
-rw-r--r--tests/test_federation.py10
-rw-r--r--tests/unittest.py2
33 files changed, 1644 insertions, 308 deletions
diff --git a/tests/events/test_presence_router.py b/tests/events/test_presence_router.py
index 6b87f571b8..3b3866bff8 100644
--- a/tests/events/test_presence_router.py
+++ b/tests/events/test_presence_router.py
@@ -17,7 +17,7 @@ from unittest.mock import Mock
 import attr
 
 from synapse.api.constants import EduTypes
-from synapse.events.presence_router import PresenceRouter
+from synapse.events.presence_router import PresenceRouter, load_legacy_presence_router
 from synapse.federation.units import Transaction
 from synapse.handlers.presence import UserPresenceState
 from synapse.module_api import ModuleApi
@@ -34,7 +34,7 @@ class PresenceRouterTestConfig:
     users_who_should_receive_all_presence = attr.ib(type=List[str], default=[])
 
 
-class PresenceRouterTestModule:
+class LegacyPresenceRouterTestModule:
     def __init__(self, config: PresenceRouterTestConfig, module_api: ModuleApi):
         self._config = config
         self._module_api = module_api
@@ -77,6 +77,53 @@ class PresenceRouterTestModule:
         return config
 
 
+class PresenceRouterTestModule:
+    def __init__(self, config: PresenceRouterTestConfig, api: ModuleApi):
+        self._config = config
+        self._module_api = api
+        api.register_presence_router_callbacks(
+            get_users_for_states=self.get_users_for_states,
+            get_interested_users=self.get_interested_users,
+        )
+
+    async def get_users_for_states(
+        self, state_updates: Iterable[UserPresenceState]
+    ) -> Dict[str, Set[UserPresenceState]]:
+        users_to_state = {
+            user_id: set(state_updates)
+            for user_id in self._config.users_who_should_receive_all_presence
+        }
+        return users_to_state
+
+    async def get_interested_users(
+        self, user_id: str
+    ) -> Union[Set[str], PresenceRouter.ALL_USERS]:
+        if user_id in self._config.users_who_should_receive_all_presence:
+            return PresenceRouter.ALL_USERS
+
+        return set()
+
+    @staticmethod
+    def parse_config(config_dict: dict) -> PresenceRouterTestConfig:
+        """Parse a configuration dictionary from the homeserver config, do
+        some validation and return a typed PresenceRouterConfig.
+
+        Args:
+            config_dict: The configuration dictionary.
+
+        Returns:
+            A validated config object.
+        """
+        # Initialise a typed config object
+        config = PresenceRouterTestConfig()
+
+        config.users_who_should_receive_all_presence = config_dict.get(
+            "users_who_should_receive_all_presence"
+        )
+
+        return config
+
+
 class PresenceRouterTestCase(FederatingHomeserverTestCase):
     servlets = [
         admin.register_servlets,
@@ -86,9 +133,17 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
     ]
 
     def make_homeserver(self, reactor, clock):
-        return self.setup_test_homeserver(
+        hs = self.setup_test_homeserver(
             federation_transport_client=Mock(spec=["send_transaction"]),
         )
+        # Load the modules into the homeserver
+        module_api = hs.get_module_api()
+        for module, config in hs.config.modules.loaded_modules:
+            module(config=config, api=module_api)
+
+        load_legacy_presence_router(hs)
+
+        return hs
 
     def prepare(self, reactor, clock, homeserver):
         self.sync_handler = self.hs.get_sync_handler()
@@ -98,7 +153,7 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
         {
             "presence": {
                 "presence_router": {
-                    "module": __name__ + ".PresenceRouterTestModule",
+                    "module": __name__ + ".LegacyPresenceRouterTestModule",
                     "config": {
                         "users_who_should_receive_all_presence": [
                             "@presence_gobbler:test",
@@ -109,7 +164,28 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
             "send_federation": True,
         }
     )
+    def test_receiving_all_presence_legacy(self):
+        self.receiving_all_presence_test_body()
+
+    @override_config(
+        {
+            "modules": [
+                {
+                    "module": __name__ + ".PresenceRouterTestModule",
+                    "config": {
+                        "users_who_should_receive_all_presence": [
+                            "@presence_gobbler:test",
+                        ]
+                    },
+                },
+            ],
+            "send_federation": True,
+        }
+    )
     def test_receiving_all_presence(self):
+        self.receiving_all_presence_test_body()
+
+    def receiving_all_presence_test_body(self):
         """Test that a user that does not share a room with another other can receive
         presence for them, due to presence routing.
         """
@@ -203,7 +279,7 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
         {
             "presence": {
                 "presence_router": {
-                    "module": __name__ + ".PresenceRouterTestModule",
+                    "module": __name__ + ".LegacyPresenceRouterTestModule",
                     "config": {
                         "users_who_should_receive_all_presence": [
                             "@presence_gobbler1:test",
@@ -216,7 +292,30 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
             "send_federation": True,
         }
     )
+    def test_send_local_online_presence_to_with_module_legacy(self):
+        self.send_local_online_presence_to_with_module_test_body()
+
+    @override_config(
+        {
+            "modules": [
+                {
+                    "module": __name__ + ".PresenceRouterTestModule",
+                    "config": {
+                        "users_who_should_receive_all_presence": [
+                            "@presence_gobbler1:test",
+                            "@presence_gobbler2:test",
+                            "@far_away_person:island",
+                        ]
+                    },
+                },
+            ],
+            "send_federation": True,
+        }
+    )
     def test_send_local_online_presence_to_with_module(self):
+        self.send_local_online_presence_to_with_module_test_body()
+
+    def send_local_online_presence_to_with_module_test_body(self):
         """Tests that send_local_presence_to_users sends local online presence to a set
         of specified local and remote users, with a custom PresenceRouter module enabled.
         """
diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py
index 84f05f6c58..339c039914 100644
--- a/tests/handlers/test_sync.py
+++ b/tests/handlers/test_sync.py
@@ -12,9 +12,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import Optional
+
+from synapse.api.constants import EventTypes, JoinRules
 from synapse.api.errors import Codes, ResourceLimitError
 from synapse.api.filtering import DEFAULT_FILTER_COLLECTION
+from synapse.api.room_versions import RoomVersions
 from synapse.handlers.sync import SyncConfig
+from synapse.rest import admin
+from synapse.rest.client import knock, login, room
+from synapse.server import HomeServer
 from synapse.types import UserID, create_requester
 
 import tests.unittest
@@ -24,8 +31,14 @@ import tests.utils
 class SyncTestCase(tests.unittest.HomeserverTestCase):
     """Tests Sync Handler."""
 
-    def prepare(self, reactor, clock, hs):
-        self.hs = hs
+    servlets = [
+        admin.register_servlets,
+        knock.register_servlets,
+        login.register_servlets,
+        room.register_servlets,
+    ]
+
+    def prepare(self, reactor, clock, hs: HomeServer):
         self.sync_handler = self.hs.get_sync_handler()
         self.store = self.hs.get_datastore()
 
@@ -68,12 +81,124 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
         )
         self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
 
+    def test_unknown_room_version(self):
+        """
+        A room with an unknown room version should not break sync (and should be excluded).
+        """
+        inviter = self.register_user("creator", "pass", admin=True)
+        inviter_tok = self.login("@creator:test", "pass")
+
+        user = self.register_user("user", "pass")
+        tok = self.login("user", "pass")
+
+        # Do an initial sync on a different device.
+        requester = create_requester(user)
+        initial_result = self.get_success(
+            self.sync_handler.wait_for_sync_for_user(
+                requester, sync_config=generate_sync_config(user, device_id="dev")
+            )
+        )
+
+        # Create a room as the user.
+        joined_room = self.helper.create_room_as(user, tok=tok)
+
+        # Invite the user to the room as someone else.
+        invite_room = self.helper.create_room_as(inviter, tok=inviter_tok)
+        self.helper.invite(invite_room, targ=user, tok=inviter_tok)
+
+        knock_room = self.helper.create_room_as(
+            inviter, room_version=RoomVersions.V7.identifier, tok=inviter_tok
+        )
+        self.helper.send_state(
+            knock_room,
+            EventTypes.JoinRules,
+            {"join_rule": JoinRules.KNOCK},
+            tok=inviter_tok,
+        )
+        channel = self.make_request(
+            "POST",
+            "/_matrix/client/r0/knock/%s" % (knock_room,),
+            b"{}",
+            tok,
+        )
+        self.assertEquals(200, channel.code, channel.result)
+
+        # The rooms should appear in the sync response.
+        result = self.get_success(
+            self.sync_handler.wait_for_sync_for_user(
+                requester, sync_config=generate_sync_config(user)
+            )
+        )
+        self.assertIn(joined_room, [r.room_id for r in result.joined])
+        self.assertIn(invite_room, [r.room_id for r in result.invited])
+        self.assertIn(knock_room, [r.room_id for r in result.knocked])
+
+        # Test a incremental sync (by providing a since_token).
+        result = self.get_success(
+            self.sync_handler.wait_for_sync_for_user(
+                requester,
+                sync_config=generate_sync_config(user, device_id="dev"),
+                since_token=initial_result.next_batch,
+            )
+        )
+        self.assertIn(joined_room, [r.room_id for r in result.joined])
+        self.assertIn(invite_room, [r.room_id for r in result.invited])
+        self.assertIn(knock_room, [r.room_id for r in result.knocked])
+
+        # Poke the database and update the room version to an unknown one.
+        for room_id in (joined_room, invite_room, knock_room):
+            self.get_success(
+                self.hs.get_datastores().main.db_pool.simple_update(
+                    "rooms",
+                    keyvalues={"room_id": room_id},
+                    updatevalues={"room_version": "unknown-room-version"},
+                    desc="updated-room-version",
+                )
+            )
+
+        # Blow away caches (supported room versions can only change due to a restart).
+        self.get_success(
+            self.store.get_rooms_for_user_with_stream_ordering.invalidate_all()
+        )
+        self.store._get_event_cache.clear()
+
+        # The rooms should be excluded from the sync response.
+        # Get a new request key.
+        result = self.get_success(
+            self.sync_handler.wait_for_sync_for_user(
+                requester, sync_config=generate_sync_config(user)
+            )
+        )
+        self.assertNotIn(joined_room, [r.room_id for r in result.joined])
+        self.assertNotIn(invite_room, [r.room_id for r in result.invited])
+        self.assertNotIn(knock_room, [r.room_id for r in result.knocked])
+
+        # The rooms should also not be in an incremental sync.
+        result = self.get_success(
+            self.sync_handler.wait_for_sync_for_user(
+                requester,
+                sync_config=generate_sync_config(user, device_id="dev"),
+                since_token=initial_result.next_batch,
+            )
+        )
+        self.assertNotIn(joined_room, [r.room_id for r in result.joined])
+        self.assertNotIn(invite_room, [r.room_id for r in result.invited])
+        self.assertNotIn(knock_room, [r.room_id for r in result.knocked])
+
+
+_request_key = 0
+
 
-def generate_sync_config(user_id: str) -> SyncConfig:
+def generate_sync_config(
+    user_id: str, device_id: Optional[str] = "device_id"
+) -> SyncConfig:
+    """Generate a sync config (with a unique request key)."""
+    global _request_key
+    _request_key += 1
     return SyncConfig(
-        user=UserID(user_id.split(":")[0][1:], user_id.split(":")[1]),
+        user=UserID.from_string(user_id),
         filter_collection=DEFAULT_FILTER_COLLECTION,
         is_guest=False,
-        request_key="request_key",
-        device_id="device_id",
+        request_key=("request_key", _request_key),
+        device_id=device_id,
     )
diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py
index db80a0bdbd..b25a06b427 100644
--- a/tests/replication/slave/storage/test_events.py
+++ b/tests/replication/slave/storage/test_events.py
@@ -20,7 +20,7 @@ from synapse.api.room_versions import RoomVersions
 from synapse.events import FrozenEvent, _EventInternalMetadata, make_event_from_dict
 from synapse.handlers.room import RoomEventSource
 from synapse.replication.slave.storage.events import SlavedEventStore
-from synapse.storage.roommember import RoomsForUser
+from synapse.storage.roommember import GetRoomsForUserWithStreamOrdering, RoomsForUser
 from synapse.types import PersistedEventPosition
 
 from tests.server import FakeTransport
@@ -150,6 +150,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
                     "invite",
                     event.event_id,
                     event.internal_metadata.stream_ordering,
+                    RoomVersions.V1.identifier,
                 )
             ],
         )
@@ -216,7 +217,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
         self.check(
             "get_rooms_for_user_with_stream_ordering",
             (USER_ID_2,),
-            {(ROOM_ID, expected_pos)},
+            {GetRoomsForUserWithStreamOrdering(ROOM_ID, expected_pos)},
         )
 
     def test_get_rooms_for_user_with_stream_ordering_with_multi_event_persist(self):
@@ -305,7 +306,10 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
                 expected_pos = PersistedEventPosition(
                     "master", j2.internal_metadata.stream_ordering
                 )
-                self.assertEqual(joined_rooms, {(ROOM_ID, expected_pos)})
+                self.assertEqual(
+                    joined_rooms,
+                    {GetRoomsForUserWithStreamOrdering(ROOM_ID, expected_pos)},
+                )
 
     event_id = 0
 
diff --git a/tests/rest/admin/test_device.py b/tests/rest/admin/test_device.py
index c4afe5c3d9..a3679be205 100644
--- a/tests/rest/admin/test_device.py
+++ b/tests/rest/admin/test_device.py
@@ -12,9 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import json
 import urllib.parse
 
+from parameterized import parameterized
+
 import synapse.rest.admin
 from synapse.api.errors import Codes
 from synapse.rest.client import login
@@ -45,49 +46,23 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
             self.other_user_device_id,
         )
 
-    def test_no_auth(self):
+    @parameterized.expand(["GET", "PUT", "DELETE"])
+    def test_no_auth(self, method: str):
         """
         Try to get a device of an user without authentication.
         """
-        channel = self.make_request("GET", self.url, b"{}")
-
-        self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
-        self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
-
-        channel = self.make_request("PUT", self.url, b"{}")
-
-        self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
-        self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
-
-        channel = self.make_request("DELETE", self.url, b"{}")
+        channel = self.make_request(method, self.url, b"{}")
 
         self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
 
-    def test_requester_is_no_admin(self):
+    @parameterized.expand(["GET", "PUT", "DELETE"])
+    def test_requester_is_no_admin(self, method: str):
         """
         If the user is not a server admin, an error is returned.
         """
         channel = self.make_request(
-            "GET",
-            self.url,
-            access_token=self.other_user_token,
-        )
-
-        self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
-        self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
-
-        channel = self.make_request(
-            "PUT",
-            self.url,
-            access_token=self.other_user_token,
-        )
-
-        self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
-        self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
-
-        channel = self.make_request(
-            "DELETE",
+            method,
             self.url,
             access_token=self.other_user_token,
         )
@@ -95,7 +70,8 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
         self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
 
-    def test_user_does_not_exist(self):
+    @parameterized.expand(["GET", "PUT", "DELETE"])
+    def test_user_does_not_exist(self, method: str):
         """
         Tests that a lookup for a user that does not exist returns a 404
         """
@@ -105,7 +81,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
         )
 
         channel = self.make_request(
-            "GET",
+            method,
             url,
             access_token=self.admin_user_tok,
         )
@@ -113,25 +89,8 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
         self.assertEqual(404, channel.code, msg=channel.json_body)
         self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
 
-        channel = self.make_request(
-            "PUT",
-            url,
-            access_token=self.admin_user_tok,
-        )
-
-        self.assertEqual(404, channel.code, msg=channel.json_body)
-        self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
-
-        channel = self.make_request(
-            "DELETE",
-            url,
-            access_token=self.admin_user_tok,
-        )
-
-        self.assertEqual(404, channel.code, msg=channel.json_body)
-        self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
-
-    def test_user_is_not_local(self):
+    @parameterized.expand(["GET", "PUT", "DELETE"])
+    def test_user_is_not_local(self, method: str):
         """
         Tests that a lookup for a user that is not a local returns a 400
         """
@@ -141,25 +100,7 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
         )
 
         channel = self.make_request(
-            "GET",
-            url,
-            access_token=self.admin_user_tok,
-        )
-
-        self.assertEqual(400, channel.code, msg=channel.json_body)
-        self.assertEqual("Can only lookup local users", channel.json_body["error"])
-
-        channel = self.make_request(
-            "PUT",
-            url,
-            access_token=self.admin_user_tok,
-        )
-
-        self.assertEqual(400, channel.code, msg=channel.json_body)
-        self.assertEqual("Can only lookup local users", channel.json_body["error"])
-
-        channel = self.make_request(
-            "DELETE",
+            method,
             url,
             access_token=self.admin_user_tok,
         )
@@ -219,12 +160,11 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
             * (synapse.handlers.device.MAX_DEVICE_DISPLAY_NAME_LEN + 1)
         }
 
-        body = json.dumps(update)
         channel = self.make_request(
             "PUT",
             self.url,
             access_token=self.admin_user_tok,
-            content=body.encode(encoding="utf_8"),
+            content=update,
         )
 
         self.assertEqual(400, channel.code, msg=channel.json_body)
@@ -275,12 +215,11 @@ class DeviceRestTestCase(unittest.HomeserverTestCase):
         Tests a normal successful update of display name
         """
         # Set new display_name
-        body = json.dumps({"display_name": "new displayname"})
         channel = self.make_request(
             "PUT",
             self.url,
             access_token=self.admin_user_tok,
-            content=body.encode(encoding="utf_8"),
+            content={"display_name": "new displayname"},
         )
 
         self.assertEqual(200, channel.code, msg=channel.json_body)
@@ -529,12 +468,11 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase):
         """
         Tests that a remove of a device that does not exist returns 200.
         """
-        body = json.dumps({"devices": ["unknown_device1", "unknown_device2"]})
         channel = self.make_request(
             "POST",
             self.url,
             access_token=self.admin_user_tok,
-            content=body.encode(encoding="utf_8"),
+            content={"devices": ["unknown_device1", "unknown_device2"]},
         )
 
         # Delete unknown devices returns status 200
@@ -560,12 +498,11 @@ class DeleteDevicesRestTestCase(unittest.HomeserverTestCase):
             device_ids.append(str(d["device_id"]))
 
         # Delete devices
-        body = json.dumps({"devices": device_ids})
         channel = self.make_request(
             "POST",
             self.url,
             access_token=self.admin_user_tok,
-            content=body.encode(encoding="utf_8"),
+            content={"devices": device_ids},
         )
 
         self.assertEqual(200, channel.code, msg=channel.json_body)
diff --git a/tests/rest/admin/test_registration_tokens.py b/tests/rest/admin/test_registration_tokens.py
new file mode 100644
index 0000000000..4927321e5a
--- /dev/null
+++ b/tests/rest/admin/test_registration_tokens.py
@@ -0,0 +1,710 @@
+# Copyright 2021 Callum Brown
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import random
+import string
+
+import synapse.rest.admin
+from synapse.api.errors import Codes
+from synapse.rest.client import login
+
+from tests import unittest
+
+
+class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase):
+    servlets = [
+        synapse.rest.admin.register_servlets,
+        login.register_servlets,
+    ]
+
+    def prepare(self, reactor, clock, hs):
+        self.store = hs.get_datastore()
+        self.admin_user = self.register_user("admin", "pass", admin=True)
+        self.admin_user_tok = self.login("admin", "pass")
+
+        self.other_user = self.register_user("user", "pass")
+        self.other_user_tok = self.login("user", "pass")
+
+        self.url = "/_synapse/admin/v1/registration_tokens"
+
+    def _new_token(self, **kwargs):
+        """Helper function to create a token."""
+        token = kwargs.get(
+            "token",
+            "".join(random.choices(string.ascii_letters, k=8)),
+        )
+        self.get_success(
+            self.store.db_pool.simple_insert(
+                "registration_tokens",
+                {
+                    "token": token,
+                    "uses_allowed": kwargs.get("uses_allowed", None),
+                    "pending": kwargs.get("pending", 0),
+                    "completed": kwargs.get("completed", 0),
+                    "expiry_time": kwargs.get("expiry_time", None),
+                },
+            )
+        )
+        return token
+
+    # CREATION
+
+    def test_create_no_auth(self):
+        """Try to create a token without authentication."""
+        channel = self.make_request("POST", self.url + "/new", {})
+        self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+    def test_create_requester_not_admin(self):
+        """Try to create a token while not an admin."""
+        channel = self.make_request(
+            "POST",
+            self.url + "/new",
+            {},
+            access_token=self.other_user_tok,
+        )
+        self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+    def test_create_using_defaults(self):
+        """Create a token using all the defaults."""
+        channel = self.make_request(
+            "POST",
+            self.url + "/new",
+            {},
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(len(channel.json_body["token"]), 16)
+        self.assertIsNone(channel.json_body["uses_allowed"])
+        self.assertIsNone(channel.json_body["expiry_time"])
+        self.assertEqual(channel.json_body["pending"], 0)
+        self.assertEqual(channel.json_body["completed"], 0)
+
+    def test_create_specifying_fields(self):
+        """Create a token specifying the value of all fields."""
+        data = {
+            "token": "abcd",
+            "uses_allowed": 1,
+            "expiry_time": self.clock.time_msec() + 1000000,
+        }
+
+        channel = self.make_request(
+            "POST",
+            self.url + "/new",
+            data,
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["token"], "abcd")
+        self.assertEqual(channel.json_body["uses_allowed"], 1)
+        self.assertEqual(channel.json_body["expiry_time"], data["expiry_time"])
+        self.assertEqual(channel.json_body["pending"], 0)
+        self.assertEqual(channel.json_body["completed"], 0)
+
+    def test_create_with_null_value(self):
+        """Create a token specifying unlimited uses and no expiry."""
+        data = {
+            "uses_allowed": None,
+            "expiry_time": None,
+        }
+
+        channel = self.make_request(
+            "POST",
+            self.url + "/new",
+            data,
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(len(channel.json_body["token"]), 16)
+        self.assertIsNone(channel.json_body["uses_allowed"])
+        self.assertIsNone(channel.json_body["expiry_time"])
+        self.assertEqual(channel.json_body["pending"], 0)
+        self.assertEqual(channel.json_body["completed"], 0)
+
+    def test_create_token_too_long(self):
+        """Check token longer than 64 chars is invalid."""
+        data = {"token": "a" * 65}
+
+        channel = self.make_request(
+            "POST",
+            self.url + "/new",
+            data,
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+    def test_create_token_invalid_chars(self):
+        """Check you can't create token with invalid characters."""
+        data = {
+            "token": "abc/def",
+        }
+
+        channel = self.make_request(
+            "POST",
+            self.url + "/new",
+            data,
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+    def test_create_token_already_exists(self):
+        """Check you can't create token that already exists."""
+        data = {
+            "token": "abcd",
+        }
+
+        channel1 = self.make_request(
+            "POST",
+            self.url + "/new",
+            data,
+            access_token=self.admin_user_tok,
+        )
+        self.assertEqual(200, int(channel1.result["code"]), msg=channel1.result["body"])
+
+        channel2 = self.make_request(
+            "POST",
+            self.url + "/new",
+            data,
+            access_token=self.admin_user_tok,
+        )
+        self.assertEqual(400, int(channel2.result["code"]), msg=channel2.result["body"])
+        self.assertEqual(channel2.json_body["errcode"], Codes.INVALID_PARAM)
+
+    def test_create_unable_to_generate_token(self):
+        """Check right error is raised when server can't generate unique token."""
+        # Create all possible single character tokens
+        tokens = []
+        for c in string.ascii_letters + string.digits + "-_":
+            tokens.append(
+                {
+                    "token": c,
+                    "uses_allowed": None,
+                    "pending": 0,
+                    "completed": 0,
+                    "expiry_time": None,
+                }
+            )
+        self.get_success(
+            self.store.db_pool.simple_insert_many(
+                "registration_tokens",
+                tokens,
+                "create_all_registration_tokens",
+            )
+        )
+
+        # Check creating a single character token fails with a 500 status code
+        channel = self.make_request(
+            "POST",
+            self.url + "/new",
+            {"length": 1},
+            access_token=self.admin_user_tok,
+        )
+        self.assertEqual(500, int(channel.result["code"]), msg=channel.result["body"])
+
+    def test_create_uses_allowed(self):
+        """Check you can only create a token with good values for uses_allowed."""
+        # Should work with 0 (token is invalid from the start)
+        channel = self.make_request(
+            "POST",
+            self.url + "/new",
+            {"uses_allowed": 0},
+            access_token=self.admin_user_tok,
+        )
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["uses_allowed"], 0)
+
+        # Should fail with negative integer
+        channel = self.make_request(
+            "POST",
+            self.url + "/new",
+            {"uses_allowed": -5},
+            access_token=self.admin_user_tok,
+        )
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+        # Should fail with float
+        channel = self.make_request(
+            "POST",
+            self.url + "/new",
+            {"uses_allowed": 1.5},
+            access_token=self.admin_user_tok,
+        )
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+    def test_create_expiry_time(self):
+        """Check you can't create a token with an invalid expiry_time."""
+        # Should fail with a time in the past
+        channel = self.make_request(
+            "POST",
+            self.url + "/new",
+            {"expiry_time": self.clock.time_msec() - 10000},
+            access_token=self.admin_user_tok,
+        )
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+        # Should fail with float
+        channel = self.make_request(
+            "POST",
+            self.url + "/new",
+            {"expiry_time": self.clock.time_msec() + 1000000.5},
+            access_token=self.admin_user_tok,
+        )
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+    def test_create_length(self):
+        """Check you can only generate a token with a valid length."""
+        # Should work with 64
+        channel = self.make_request(
+            "POST",
+            self.url + "/new",
+            {"length": 64},
+            access_token=self.admin_user_tok,
+        )
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(len(channel.json_body["token"]), 64)
+
+        # Should fail with 0
+        channel = self.make_request(
+            "POST",
+            self.url + "/new",
+            {"length": 0},
+            access_token=self.admin_user_tok,
+        )
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+        # Should fail with a negative integer
+        channel = self.make_request(
+            "POST",
+            self.url + "/new",
+            {"length": -5},
+            access_token=self.admin_user_tok,
+        )
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+        # Should fail with a float
+        channel = self.make_request(
+            "POST",
+            self.url + "/new",
+            {"length": 8.5},
+            access_token=self.admin_user_tok,
+        )
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+        # Should fail with 65
+        channel = self.make_request(
+            "POST",
+            self.url + "/new",
+            {"length": 65},
+            access_token=self.admin_user_tok,
+        )
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+    # UPDATING
+
+    def test_update_no_auth(self):
+        """Try to update a token without authentication."""
+        channel = self.make_request(
+            "PUT",
+            self.url + "/1234",  # Token doesn't exist but that doesn't matter
+            {},
+        )
+        self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+    def test_update_requester_not_admin(self):
+        """Try to update a token while not an admin."""
+        channel = self.make_request(
+            "PUT",
+            self.url + "/1234",  # Token doesn't exist but that doesn't matter
+            {},
+            access_token=self.other_user_tok,
+        )
+        self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+    def test_update_non_existent(self):
+        """Try to update a token that doesn't exist."""
+        channel = self.make_request(
+            "PUT",
+            self.url + "/1234",
+            {"uses_allowed": 1},
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
+
+    def test_update_uses_allowed(self):
+        """Test updating just uses_allowed."""
+        # Create new token using default values
+        token = self._new_token()
+
+        # Should succeed with 1
+        channel = self.make_request(
+            "PUT",
+            self.url + "/" + token,
+            {"uses_allowed": 1},
+            access_token=self.admin_user_tok,
+        )
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["uses_allowed"], 1)
+        self.assertIsNone(channel.json_body["expiry_time"])
+
+        # Should succeed with 0 (makes token invalid)
+        channel = self.make_request(
+            "PUT",
+            self.url + "/" + token,
+            {"uses_allowed": 0},
+            access_token=self.admin_user_tok,
+        )
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["uses_allowed"], 0)
+        self.assertIsNone(channel.json_body["expiry_time"])
+
+        # Should succeed with null
+        channel = self.make_request(
+            "PUT",
+            self.url + "/" + token,
+            {"uses_allowed": None},
+            access_token=self.admin_user_tok,
+        )
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertIsNone(channel.json_body["uses_allowed"])
+        self.assertIsNone(channel.json_body["expiry_time"])
+
+        # Should fail with a float
+        channel = self.make_request(
+            "PUT",
+            self.url + "/" + token,
+            {"uses_allowed": 1.5},
+            access_token=self.admin_user_tok,
+        )
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+        # Should fail with a negative integer
+        channel = self.make_request(
+            "PUT",
+            self.url + "/" + token,
+            {"uses_allowed": -5},
+            access_token=self.admin_user_tok,
+        )
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+    def test_update_expiry_time(self):
+        """Test updating just expiry_time."""
+        # Create new token using default values
+        token = self._new_token()
+        new_expiry_time = self.clock.time_msec() + 1000000
+
+        # Should succeed with a time in the future
+        channel = self.make_request(
+            "PUT",
+            self.url + "/" + token,
+            {"expiry_time": new_expiry_time},
+            access_token=self.admin_user_tok,
+        )
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["expiry_time"], new_expiry_time)
+        self.assertIsNone(channel.json_body["uses_allowed"])
+
+        # Should succeed with null
+        channel = self.make_request(
+            "PUT",
+            self.url + "/" + token,
+            {"expiry_time": None},
+            access_token=self.admin_user_tok,
+        )
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertIsNone(channel.json_body["expiry_time"])
+        self.assertIsNone(channel.json_body["uses_allowed"])
+
+        # Should fail with a time in the past
+        past_time = self.clock.time_msec() - 10000
+        channel = self.make_request(
+            "PUT",
+            self.url + "/" + token,
+            {"expiry_time": past_time},
+            access_token=self.admin_user_tok,
+        )
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+        # Should fail a float
+        channel = self.make_request(
+            "PUT",
+            self.url + "/" + token,
+            {"expiry_time": new_expiry_time + 0.5},
+            access_token=self.admin_user_tok,
+        )
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+    def test_update_both(self):
+        """Test updating both uses_allowed and expiry_time."""
+        # Create new token using default values
+        token = self._new_token()
+        new_expiry_time = self.clock.time_msec() + 1000000
+
+        data = {
+            "uses_allowed": 1,
+            "expiry_time": new_expiry_time,
+        }
+
+        channel = self.make_request(
+            "PUT",
+            self.url + "/" + token,
+            data,
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["uses_allowed"], 1)
+        self.assertEqual(channel.json_body["expiry_time"], new_expiry_time)
+
+    def test_update_invalid_type(self):
+        """Test using invalid types doesn't work."""
+        # Create new token using default values
+        token = self._new_token()
+
+        data = {
+            "uses_allowed": False,
+            "expiry_time": "1626430124000",
+        }
+
+        channel = self.make_request(
+            "PUT",
+            self.url + "/" + token,
+            data,
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
+
+    # DELETING
+
+    def test_delete_no_auth(self):
+        """Try to delete a token without authentication."""
+        channel = self.make_request(
+            "DELETE",
+            self.url + "/1234",  # Token doesn't exist but that doesn't matter
+            {},
+        )
+        self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+    def test_delete_requester_not_admin(self):
+        """Try to delete a token while not an admin."""
+        channel = self.make_request(
+            "DELETE",
+            self.url + "/1234",  # Token doesn't exist but that doesn't matter
+            {},
+            access_token=self.other_user_tok,
+        )
+        self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+    def test_delete_non_existent(self):
+        """Try to delete a token that doesn't exist."""
+        channel = self.make_request(
+            "DELETE",
+            self.url + "/1234",
+            {},
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
+
+    def test_delete(self):
+        """Test deleting a token."""
+        # Create new token using default values
+        token = self._new_token()
+
+        channel = self.make_request(
+            "DELETE",
+            self.url + "/" + token,
+            {},
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+
+    # GETTING ONE
+
+    def test_get_no_auth(self):
+        """Try to get a token without authentication."""
+        channel = self.make_request(
+            "GET",
+            self.url + "/1234",  # Token doesn't exist but that doesn't matter
+            {},
+        )
+        self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+    def test_get_requester_not_admin(self):
+        """Try to get a token while not an admin."""
+        channel = self.make_request(
+            "GET",
+            self.url + "/1234",  # Token doesn't exist but that doesn't matter
+            {},
+            access_token=self.other_user_tok,
+        )
+        self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+    def test_get_non_existent(self):
+        """Try to get a token that doesn't exist."""
+        channel = self.make_request(
+            "GET",
+            self.url + "/1234",
+            {},
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
+
+    def test_get(self):
+        """Test getting a token."""
+        # Create new token using default values
+        token = self._new_token()
+
+        channel = self.make_request(
+            "GET",
+            self.url + "/" + token,
+            {},
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(channel.json_body["token"], token)
+        self.assertIsNone(channel.json_body["uses_allowed"])
+        self.assertIsNone(channel.json_body["expiry_time"])
+        self.assertEqual(channel.json_body["pending"], 0)
+        self.assertEqual(channel.json_body["completed"], 0)
+
+    # LISTING
+
+    def test_list_no_auth(self):
+        """Try to list tokens without authentication."""
+        channel = self.make_request("GET", self.url, {})
+        self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
+
+    def test_list_requester_not_admin(self):
+        """Try to list tokens while not an admin."""
+        channel = self.make_request(
+            "GET",
+            self.url,
+            {},
+            access_token=self.other_user_tok,
+        )
+        self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
+
+    def test_list_all(self):
+        """Test listing all tokens."""
+        # Create new token using default values
+        token = self._new_token()
+
+        channel = self.make_request(
+            "GET",
+            self.url,
+            {},
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(len(channel.json_body["registration_tokens"]), 1)
+        token_info = channel.json_body["registration_tokens"][0]
+        self.assertEqual(token_info["token"], token)
+        self.assertIsNone(token_info["uses_allowed"])
+        self.assertIsNone(token_info["expiry_time"])
+        self.assertEqual(token_info["pending"], 0)
+        self.assertEqual(token_info["completed"], 0)
+
+    def test_list_invalid_query_parameter(self):
+        """Test with `valid` query parameter not `true` or `false`."""
+        channel = self.make_request(
+            "GET",
+            self.url + "?valid=x",
+            {},
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+
+    def _test_list_query_parameter(self, valid: str):
+        """Helper used to test both valid=true and valid=false."""
+        # Create 2 valid and 2 invalid tokens.
+        now = self.hs.get_clock().time_msec()
+        # Create always valid token
+        valid1 = self._new_token()
+        # Create token that hasn't been used up
+        valid2 = self._new_token(uses_allowed=1)
+        # Create token that has expired
+        invalid1 = self._new_token(expiry_time=now - 10000)
+        # Create token that has been used up but hasn't expired
+        invalid2 = self._new_token(
+            uses_allowed=2,
+            pending=1,
+            completed=1,
+            expiry_time=now + 1000000,
+        )
+
+        if valid == "true":
+            tokens = [valid1, valid2]
+        else:
+            tokens = [invalid1, invalid2]
+
+        channel = self.make_request(
+            "GET",
+            self.url + "?valid=" + valid,
+            {},
+            access_token=self.admin_user_tok,
+        )
+
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(len(channel.json_body["registration_tokens"]), 2)
+        token_info_1 = channel.json_body["registration_tokens"][0]
+        token_info_2 = channel.json_body["registration_tokens"][1]
+        self.assertIn(token_info_1["token"], tokens)
+        self.assertIn(token_info_2["token"], tokens)
+
+    def test_list_valid(self):
+        """Test listing just valid tokens."""
+        self._test_list_query_parameter(valid="true")
+
+    def test_list_invalid(self):
+        """Test listing just invalid tokens."""
+        self._test_list_query_parameter(valid="false")
diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py
index c9d4731017..40e032df7f 100644
--- a/tests/rest/admin/test_room.py
+++ b/tests/rest/admin/test_room.py
@@ -29,123 +29,6 @@ from tests import unittest
 """Tests admin REST events for /rooms paths."""
 
 
-class ShutdownRoomTestCase(unittest.HomeserverTestCase):
-    servlets = [
-        synapse.rest.admin.register_servlets_for_client_rest_resource,
-        login.register_servlets,
-        events.register_servlets,
-        room.register_servlets,
-        room.register_deprecated_servlets,
-    ]
-
-    def prepare(self, reactor, clock, hs):
-        self.event_creation_handler = hs.get_event_creation_handler()
-        hs.config.user_consent_version = "1"
-
-        consent_uri_builder = Mock()
-        consent_uri_builder.build_user_consent_uri.return_value = "http://example.com"
-        self.event_creation_handler._consent_uri_builder = consent_uri_builder
-
-        self.store = hs.get_datastore()
-
-        self.admin_user = self.register_user("admin", "pass", admin=True)
-        self.admin_user_tok = self.login("admin", "pass")
-
-        self.other_user = self.register_user("user", "pass")
-        self.other_user_token = self.login("user", "pass")
-
-        # Mark the admin user as having consented
-        self.get_success(self.store.user_set_consent_version(self.admin_user, "1"))
-
-    def test_shutdown_room_consent(self):
-        """Test that we can shutdown rooms with local users who have not
-        yet accepted the privacy policy. This used to fail when we tried to
-        force part the user from the old room.
-        """
-        self.event_creation_handler._block_events_without_consent_error = None
-
-        room_id = self.helper.create_room_as(self.other_user, tok=self.other_user_token)
-
-        # Assert one user in room
-        users_in_room = self.get_success(self.store.get_users_in_room(room_id))
-        self.assertEqual([self.other_user], users_in_room)
-
-        # Enable require consent to send events
-        self.event_creation_handler._block_events_without_consent_error = "Error"
-
-        # Assert that the user is getting consent error
-        self.helper.send(
-            room_id, body="foo", tok=self.other_user_token, expect_code=403
-        )
-
-        # Test that the admin can still send shutdown
-        url = "/_synapse/admin/v1/shutdown_room/" + room_id
-        channel = self.make_request(
-            "POST",
-            url.encode("ascii"),
-            json.dumps({"new_room_user_id": self.admin_user}),
-            access_token=self.admin_user_tok,
-        )
-
-        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
-
-        # Assert there is now no longer anyone in the room
-        users_in_room = self.get_success(self.store.get_users_in_room(room_id))
-        self.assertEqual([], users_in_room)
-
-    def test_shutdown_room_block_peek(self):
-        """Test that a world_readable room can no longer be peeked into after
-        it has been shut down.
-        """
-
-        self.event_creation_handler._block_events_without_consent_error = None
-
-        room_id = self.helper.create_room_as(self.other_user, tok=self.other_user_token)
-
-        # Enable world readable
-        url = "rooms/%s/state/m.room.history_visibility" % (room_id,)
-        channel = self.make_request(
-            "PUT",
-            url.encode("ascii"),
-            json.dumps({"history_visibility": "world_readable"}),
-            access_token=self.other_user_token,
-        )
-        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
-
-        # Test that the admin can still send shutdown
-        url = "/_synapse/admin/v1/shutdown_room/" + room_id
-        channel = self.make_request(
-            "POST",
-            url.encode("ascii"),
-            json.dumps({"new_room_user_id": self.admin_user}),
-            access_token=self.admin_user_tok,
-        )
-
-        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
-
-        # Assert we can no longer peek into the room
-        self._assert_peek(room_id, expect_code=403)
-
-    def _assert_peek(self, room_id, expect_code):
-        """Assert that the admin user can (or cannot) peek into the room."""
-
-        url = "rooms/%s/initialSync" % (room_id,)
-        channel = self.make_request(
-            "GET", url.encode("ascii"), access_token=self.admin_user_tok
-        )
-        self.assertEqual(
-            expect_code, int(channel.result["code"]), msg=channel.result["body"]
-        )
-
-        url = "events?timeout=0&room_id=" + room_id
-        channel = self.make_request(
-            "GET", url.encode("ascii"), access_token=self.admin_user_tok
-        )
-        self.assertEqual(
-            expect_code, int(channel.result["code"]), msg=channel.result["body"]
-        )
-
-
 @parameterized_class(
     ("method", "url_template"),
     [
@@ -557,51 +440,6 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
         )
 
 
-class PurgeRoomTestCase(unittest.HomeserverTestCase):
-    """Test /purge_room admin API."""
-
-    servlets = [
-        synapse.rest.admin.register_servlets,
-        login.register_servlets,
-        room.register_servlets,
-    ]
-
-    def prepare(self, reactor, clock, hs):
-        self.store = hs.get_datastore()
-
-        self.admin_user = self.register_user("admin", "pass", admin=True)
-        self.admin_user_tok = self.login("admin", "pass")
-
-    def test_purge_room(self):
-        room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
-
-        # All users have to have left the room.
-        self.helper.leave(room_id, user=self.admin_user, tok=self.admin_user_tok)
-
-        url = "/_synapse/admin/v1/purge_room"
-        channel = self.make_request(
-            "POST",
-            url.encode("ascii"),
-            {"room_id": room_id},
-            access_token=self.admin_user_tok,
-        )
-
-        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
-
-        # Test that the following tables have been purged of all rows related to the room.
-        for table in PURGE_TABLES:
-            count = self.get_success(
-                self.store.db_pool.simple_select_one_onecol(
-                    table=table,
-                    keyvalues={"room_id": room_id},
-                    retcol="COUNT(*)",
-                    desc="test_purge_room",
-                )
-            )
-
-            self.assertEqual(count, 0, msg=f"Rows not purged in {table}")
-
-
 class RoomTestCase(unittest.HomeserverTestCase):
     """Test /room admin API."""
 
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index ef77275238..ee204c404b 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -1431,12 +1431,14 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         self.assertEqual("Bob's name", channel.json_body["displayname"])
         self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
         self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
+        self.assertEqual(1, len(channel.json_body["threepids"]))
         self.assertEqual(
             "external_id1", channel.json_body["external_ids"][0]["external_id"]
         )
         self.assertEqual(
             "auth_provider1", channel.json_body["external_ids"][0]["auth_provider"]
         )
+        self.assertEqual(1, len(channel.json_body["external_ids"]))
         self.assertFalse(channel.json_body["admin"])
         self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"])
         self._check_fields(channel.json_body)
@@ -1676,18 +1678,53 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         Test setting threepid for an other user.
         """
 
-        # Delete old and add new threepid to user
+        # Add two threepids to user
         channel = self.make_request(
             "PUT",
             self.url_other_user,
             access_token=self.admin_user_tok,
-            content={"threepids": [{"medium": "email", "address": "bob3@bob.bob"}]},
+            content={
+                "threepids": [
+                    {"medium": "email", "address": "bob1@bob.bob"},
+                    {"medium": "email", "address": "bob2@bob.bob"},
+                ],
+            },
         )
 
         self.assertEqual(200, channel.code, msg=channel.json_body)
         self.assertEqual("@user:test", channel.json_body["name"])
+        self.assertEqual(2, len(channel.json_body["threepids"]))
+        # result does not always have the same sort order, therefore it becomes sorted
+        sorted_result = sorted(
+            channel.json_body["threepids"], key=lambda k: k["address"]
+        )
+        self.assertEqual("email", sorted_result[0]["medium"])
+        self.assertEqual("bob1@bob.bob", sorted_result[0]["address"])
+        self.assertEqual("email", sorted_result[1]["medium"])
+        self.assertEqual("bob2@bob.bob", sorted_result[1]["address"])
+        self._check_fields(channel.json_body)
+
+        # Set a new and remove a threepid
+        channel = self.make_request(
+            "PUT",
+            self.url_other_user,
+            access_token=self.admin_user_tok,
+            content={
+                "threepids": [
+                    {"medium": "email", "address": "bob2@bob.bob"},
+                    {"medium": "email", "address": "bob3@bob.bob"},
+                ],
+            },
+        )
+
+        self.assertEqual(200, channel.code, msg=channel.json_body)
+        self.assertEqual("@user:test", channel.json_body["name"])
+        self.assertEqual(2, len(channel.json_body["threepids"]))
         self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
-        self.assertEqual("bob3@bob.bob", channel.json_body["threepids"][0]["address"])
+        self.assertEqual("bob2@bob.bob", channel.json_body["threepids"][0]["address"])
+        self.assertEqual("email", channel.json_body["threepids"][1]["medium"])
+        self.assertEqual("bob3@bob.bob", channel.json_body["threepids"][1]["address"])
+        self._check_fields(channel.json_body)
 
         # Get user
         channel = self.make_request(
@@ -1698,8 +1735,24 @@ class UserRestTestCase(unittest.HomeserverTestCase):
 
         self.assertEqual(200, channel.code, msg=channel.json_body)
         self.assertEqual("@user:test", channel.json_body["name"])
+        self.assertEqual(2, len(channel.json_body["threepids"]))
         self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
-        self.assertEqual("bob3@bob.bob", channel.json_body["threepids"][0]["address"])
+        self.assertEqual("bob2@bob.bob", channel.json_body["threepids"][0]["address"])
+        self.assertEqual("email", channel.json_body["threepids"][1]["medium"])
+        self.assertEqual("bob3@bob.bob", channel.json_body["threepids"][1]["address"])
+        self._check_fields(channel.json_body)
+
+        # Remove threepids
+        channel = self.make_request(
+            "PUT",
+            self.url_other_user,
+            access_token=self.admin_user_tok,
+            content={"threepids": []},
+        )
+        self.assertEqual(200, channel.code, msg=channel.json_body)
+        self.assertEqual("@user:test", channel.json_body["name"])
+        self.assertEqual(0, len(channel.json_body["threepids"]))
+        self._check_fields(channel.json_body)
 
     def test_set_external_id(self):
         """
@@ -1778,6 +1831,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
 
         self.assertEqual(200, channel.code, msg=channel.json_body)
         self.assertEqual("@user:test", channel.json_body["name"])
+        self.assertEqual(2, len(channel.json_body["external_ids"]))
         self.assertEqual(
             channel.json_body["external_ids"],
             [
diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/test_account.py
index b946fca8b3..b946fca8b3 100644
--- a/tests/rest/client/v2_alpha/test_account.py
+++ b/tests/rest/client/test_account.py
diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/test_auth.py
index cf5cfb910c..e2fcbdc63a 100644
--- a/tests/rest/client/v2_alpha/test_auth.py
+++ b/tests/rest/client/test_auth.py
@@ -25,7 +25,7 @@ from synapse.types import JsonDict, UserID
 
 from tests import unittest
 from tests.handlers.test_oidc import HAS_OIDC
-from tests.rest.client.v1.utils import TEST_OIDC_CONFIG
+from tests.rest.client.utils import TEST_OIDC_CONFIG
 from tests.server import FakeChannel
 from tests.unittest import override_config, skip_unless
 
diff --git a/tests/rest/client/v2_alpha/test_capabilities.py b/tests/rest/client/test_capabilities.py
index 13b3c5f499..422361b62a 100644
--- a/tests/rest/client/v2_alpha/test_capabilities.py
+++ b/tests/rest/client/test_capabilities.py
@@ -30,19 +30,22 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
     def make_homeserver(self, reactor, clock):
         self.url = b"/_matrix/client/r0/capabilities"
         hs = self.setup_test_homeserver()
-        self.store = hs.get_datastore()
         self.config = hs.config
         self.auth_handler = hs.get_auth_handler()
         return hs
 
+    def prepare(self, reactor, clock, hs):
+        self.localpart = "user"
+        self.password = "pass"
+        self.user = self.register_user(self.localpart, self.password)
+
     def test_check_auth_required(self):
         channel = self.make_request("GET", self.url)
 
         self.assertEqual(channel.code, 401)
 
     def test_get_room_version_capabilities(self):
-        self.register_user("user", "pass")
-        access_token = self.login("user", "pass")
+        access_token = self.login(self.localpart, self.password)
 
         channel = self.make_request("GET", self.url, access_token=access_token)
         capabilities = channel.json_body["capabilities"]
@@ -57,10 +60,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
         )
 
     def test_get_change_password_capabilities_password_login(self):
-        localpart = "user"
-        password = "pass"
-        user = self.register_user(localpart, password)
-        access_token = self.login(user, password)
+        access_token = self.login(self.localpart, self.password)
 
         channel = self.make_request("GET", self.url, access_token=access_token)
         capabilities = channel.json_body["capabilities"]
@@ -70,12 +70,9 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
 
     @override_config({"password_config": {"localdb_enabled": False}})
     def test_get_change_password_capabilities_localdb_disabled(self):
-        localpart = "user"
-        password = "pass"
-        user = self.register_user(localpart, password)
         access_token = self.get_success(
             self.auth_handler.get_access_token_for_user_id(
-                user, device_id=None, valid_until_ms=None
+                self.user, device_id=None, valid_until_ms=None
             )
         )
 
@@ -87,12 +84,9 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
 
     @override_config({"password_config": {"enabled": False}})
     def test_get_change_password_capabilities_password_disabled(self):
-        localpart = "user"
-        password = "pass"
-        user = self.register_user(localpart, password)
         access_token = self.get_success(
             self.auth_handler.get_access_token_for_user_id(
-                user, device_id=None, valid_until_ms=None
+                self.user, device_id=None, valid_until_ms=None
             )
         )
 
@@ -102,14 +96,86 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
         self.assertEqual(channel.code, 200)
         self.assertFalse(capabilities["m.change_password"]["enabled"])
 
+    def test_get_change_users_attributes_capabilities_when_msc3283_disabled(self):
+        """Test that per default msc3283 is disabled server returns `m.change_password`."""
+        access_token = self.login(self.localpart, self.password)
+
+        channel = self.make_request("GET", self.url, access_token=access_token)
+        capabilities = channel.json_body["capabilities"]
+
+        self.assertEqual(channel.code, 200)
+        self.assertTrue(capabilities["m.change_password"]["enabled"])
+        self.assertNotIn("org.matrix.msc3283.set_displayname", capabilities)
+        self.assertNotIn("org.matrix.msc3283.set_avatar_url", capabilities)
+        self.assertNotIn("org.matrix.msc3283.3pid_changes", capabilities)
+
+    @override_config({"experimental_features": {"msc3283_enabled": True}})
+    def test_get_change_users_attributes_capabilities_when_msc3283_enabled(self):
+        """Test if msc3283 is enabled server returns capabilities."""
+        access_token = self.login(self.localpart, self.password)
+
+        channel = self.make_request("GET", self.url, access_token=access_token)
+        capabilities = channel.json_body["capabilities"]
+
+        self.assertEqual(channel.code, 200)
+        self.assertTrue(capabilities["m.change_password"]["enabled"])
+        self.assertTrue(capabilities["org.matrix.msc3283.set_displayname"]["enabled"])
+        self.assertTrue(capabilities["org.matrix.msc3283.set_avatar_url"]["enabled"])
+        self.assertTrue(capabilities["org.matrix.msc3283.3pid_changes"]["enabled"])
+
+    @override_config(
+        {
+            "enable_set_displayname": False,
+            "experimental_features": {"msc3283_enabled": True},
+        }
+    )
+    def test_get_set_displayname_capabilities_displayname_disabled(self):
+        """Test if set displayname is disabled that the server responds it."""
+        access_token = self.login(self.localpart, self.password)
+
+        channel = self.make_request("GET", self.url, access_token=access_token)
+        capabilities = channel.json_body["capabilities"]
+
+        self.assertEqual(channel.code, 200)
+        self.assertFalse(capabilities["org.matrix.msc3283.set_displayname"]["enabled"])
+
+    @override_config(
+        {
+            "enable_set_avatar_url": False,
+            "experimental_features": {"msc3283_enabled": True},
+        }
+    )
+    def test_get_set_avatar_url_capabilities_avatar_url_disabled(self):
+        """Test if set avatar_url is disabled that the server responds it."""
+        access_token = self.login(self.localpart, self.password)
+
+        channel = self.make_request("GET", self.url, access_token=access_token)
+        capabilities = channel.json_body["capabilities"]
+
+        self.assertEqual(channel.code, 200)
+        self.assertFalse(capabilities["org.matrix.msc3283.set_avatar_url"]["enabled"])
+
+    @override_config(
+        {
+            "enable_3pid_changes": False,
+            "experimental_features": {"msc3283_enabled": True},
+        }
+    )
+    def test_change_3pid_capabilities_3pid_disabled(self):
+        """Test if change 3pid is disabled that the server responds it."""
+        access_token = self.login(self.localpart, self.password)
+
+        channel = self.make_request("GET", self.url, access_token=access_token)
+        capabilities = channel.json_body["capabilities"]
+
+        self.assertEqual(channel.code, 200)
+        self.assertFalse(capabilities["org.matrix.msc3283.3pid_changes"]["enabled"])
+
     @override_config({"experimental_features": {"msc3244_enabled": False}})
     def test_get_does_not_include_msc3244_fields_when_disabled(self):
-        localpart = "user"
-        password = "pass"
-        user = self.register_user(localpart, password)
         access_token = self.get_success(
             self.auth_handler.get_access_token_for_user_id(
-                user, device_id=None, valid_until_ms=None
+                self.user, device_id=None, valid_until_ms=None
             )
         )
 
@@ -122,12 +188,9 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
         )
 
     def test_get_does_include_msc3244_fields_when_enabled(self):
-        localpart = "user"
-        password = "pass"
-        user = self.register_user(localpart, password)
         access_token = self.get_success(
             self.auth_handler.get_access_token_for_user_id(
-                user, device_id=None, valid_until_ms=None
+                self.user, device_id=None, valid_until_ms=None
             )
         )
 
diff --git a/tests/rest/client/v1/test_directory.py b/tests/rest/client/test_directory.py
index d2181ea907..d2181ea907 100644
--- a/tests/rest/client/v1/test_directory.py
+++ b/tests/rest/client/test_directory.py
diff --git a/tests/rest/client/v1/test_events.py b/tests/rest/client/test_events.py
index a90294003e..a90294003e 100644
--- a/tests/rest/client/v1/test_events.py
+++ b/tests/rest/client/test_events.py
diff --git a/tests/rest/client/v2_alpha/test_filter.py b/tests/rest/client/test_filter.py
index 475c6bed3d..475c6bed3d 100644
--- a/tests/rest/client/v2_alpha/test_filter.py
+++ b/tests/rest/client/test_filter.py
diff --git a/tests/rest/client/test_keys.py b/tests/rest/client/test_keys.py
new file mode 100644
index 0000000000..d7fa635eae
--- /dev/null
+++ b/tests/rest/client/test_keys.py
@@ -0,0 +1,91 @@
+#  Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+#  Licensed under the Apache License, Version 2.0 (the "License");
+#  you may not use this file except in compliance with the License.
+#  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+#  limitations under the License
+
+from http import HTTPStatus
+
+from synapse.api.errors import Codes
+from synapse.rest import admin
+from synapse.rest.client import keys, login
+
+from tests import unittest
+
+
+class KeyQueryTestCase(unittest.HomeserverTestCase):
+    servlets = [
+        keys.register_servlets,
+        admin.register_servlets_for_client_rest_resource,
+        login.register_servlets,
+    ]
+
+    def test_rejects_device_id_ice_key_outside_of_list(self):
+        self.register_user("alice", "wonderland")
+        alice_token = self.login("alice", "wonderland")
+        bob = self.register_user("bob", "uncle")
+        channel = self.make_request(
+            "POST",
+            "/_matrix/client/r0/keys/query",
+            {
+                "device_keys": {
+                    bob: "device_id1",
+                },
+            },
+            alice_token,
+        )
+        self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
+        self.assertEqual(
+            channel.json_body["errcode"],
+            Codes.BAD_JSON,
+            channel.result,
+        )
+
+    def test_rejects_device_key_given_as_map_to_bool(self):
+        self.register_user("alice", "wonderland")
+        alice_token = self.login("alice", "wonderland")
+        bob = self.register_user("bob", "uncle")
+        channel = self.make_request(
+            "POST",
+            "/_matrix/client/r0/keys/query",
+            {
+                "device_keys": {
+                    bob: {
+                        "device_id1": True,
+                    },
+                },
+            },
+            alice_token,
+        )
+
+        self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
+        self.assertEqual(
+            channel.json_body["errcode"],
+            Codes.BAD_JSON,
+            channel.result,
+        )
+
+    def test_requires_device_key(self):
+        """`device_keys` is required. We should complain if it's missing."""
+        self.register_user("alice", "wonderland")
+        alice_token = self.login("alice", "wonderland")
+        channel = self.make_request(
+            "POST",
+            "/_matrix/client/r0/keys/query",
+            {},
+            alice_token,
+        )
+        self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
+        self.assertEqual(
+            channel.json_body["errcode"],
+            Codes.BAD_JSON,
+            channel.result,
+        )
diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/test_login.py
index eba3552b19..5b2243fe52 100644
--- a/tests/rest/client/v1/test_login.py
+++ b/tests/rest/client/test_login.py
@@ -32,7 +32,7 @@ from synapse.types import create_requester
 from tests import unittest
 from tests.handlers.test_oidc import HAS_OIDC
 from tests.handlers.test_saml import has_saml2
-from tests.rest.client.v1.utils import TEST_OIDC_AUTH_ENDPOINT, TEST_OIDC_CONFIG
+from tests.rest.client.utils import TEST_OIDC_AUTH_ENDPOINT, TEST_OIDC_CONFIG
 from tests.test_utils.html_parsers import TestHtmlParser
 from tests.unittest import HomeserverTestCase, override_config, skip_unless
 
diff --git a/tests/rest/client/v2_alpha/test_password_policy.py b/tests/rest/client/test_password_policy.py
index 3cf5871899..3cf5871899 100644
--- a/tests/rest/client/v2_alpha/test_password_policy.py
+++ b/tests/rest/client/test_password_policy.py
diff --git a/tests/rest/client/v1/test_presence.py b/tests/rest/client/test_presence.py
index 1d152352d1..1d152352d1 100644
--- a/tests/rest/client/v1/test_presence.py
+++ b/tests/rest/client/test_presence.py
diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/test_profile.py
index 2860579c2e..2860579c2e 100644
--- a/tests/rest/client/v1/test_profile.py
+++ b/tests/rest/client/test_profile.py
diff --git a/tests/rest/client/v1/test_push_rule_attrs.py b/tests/rest/client/test_push_rule_attrs.py
index d0ce91ccd9..d0ce91ccd9 100644
--- a/tests/rest/client/v1/test_push_rule_attrs.py
+++ b/tests/rest/client/test_push_rule_attrs.py
diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/test_register.py
index fecda037a5..9f3ab2c985 100644
--- a/tests/rest/client/v2_alpha/test_register.py
+++ b/tests/rest/client/test_register.py
@@ -24,6 +24,7 @@ from synapse.api.constants import APP_SERVICE_REGISTRATION_TYPE, LoginType
 from synapse.api.errors import Codes
 from synapse.appservice import ApplicationService
 from synapse.rest.client import account, account_validity, login, logout, register, sync
+from synapse.storage._base import db_to_json
 
 from tests import unittest
 from tests.unittest import override_config
@@ -204,6 +205,371 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
 
         self.assertEquals(channel.result["code"], b"200", channel.result)
 
+    @override_config({"registration_requires_token": True})
+    def test_POST_registration_requires_token(self):
+        username = "kermit"
+        device_id = "frogfone"
+        token = "abcd"
+        store = self.hs.get_datastore()
+        self.get_success(
+            store.db_pool.simple_insert(
+                "registration_tokens",
+                {
+                    "token": token,
+                    "uses_allowed": None,
+                    "pending": 0,
+                    "completed": 0,
+                    "expiry_time": None,
+                },
+            )
+        )
+        params = {
+            "username": username,
+            "password": "monkey",
+            "device_id": device_id,
+        }
+
+        # Request without auth to get flows and session
+        channel = self.make_request(b"POST", self.url, json.dumps(params))
+        self.assertEquals(channel.result["code"], b"401", channel.result)
+        flows = channel.json_body["flows"]
+        # Synapse adds a dummy stage to differentiate flows where otherwise one
+        # flow would be a subset of another flow.
+        self.assertCountEqual(
+            [[LoginType.REGISTRATION_TOKEN, LoginType.DUMMY]],
+            (f["stages"] for f in flows),
+        )
+        session = channel.json_body["session"]
+
+        # Do the registration token stage and check it has completed
+        params["auth"] = {
+            "type": LoginType.REGISTRATION_TOKEN,
+            "token": token,
+            "session": session,
+        }
+        request_data = json.dumps(params)
+        channel = self.make_request(b"POST", self.url, request_data)
+        self.assertEquals(channel.result["code"], b"401", channel.result)
+        completed = channel.json_body["completed"]
+        self.assertCountEqual([LoginType.REGISTRATION_TOKEN], completed)
+
+        # Do the m.login.dummy stage and check registration was successful
+        params["auth"] = {
+            "type": LoginType.DUMMY,
+            "session": session,
+        }
+        request_data = json.dumps(params)
+        channel = self.make_request(b"POST", self.url, request_data)
+        det_data = {
+            "user_id": f"@{username}:{self.hs.hostname}",
+            "home_server": self.hs.hostname,
+            "device_id": device_id,
+        }
+        self.assertEquals(channel.result["code"], b"200", channel.result)
+        self.assertDictContainsSubset(det_data, channel.json_body)
+
+        # Check the `completed` counter has been incremented and pending is 0
+        res = self.get_success(
+            store.db_pool.simple_select_one(
+                "registration_tokens",
+                keyvalues={"token": token},
+                retcols=["pending", "completed"],
+            )
+        )
+        self.assertEquals(res["completed"], 1)
+        self.assertEquals(res["pending"], 0)
+
+    @override_config({"registration_requires_token": True})
+    def test_POST_registration_token_invalid(self):
+        params = {
+            "username": "kermit",
+            "password": "monkey",
+        }
+        # Request without auth to get session
+        channel = self.make_request(b"POST", self.url, json.dumps(params))
+        session = channel.json_body["session"]
+
+        # Test with token param missing (invalid)
+        params["auth"] = {
+            "type": LoginType.REGISTRATION_TOKEN,
+            "session": session,
+        }
+        channel = self.make_request(b"POST", self.url, json.dumps(params))
+        self.assertEquals(channel.result["code"], b"401", channel.result)
+        self.assertEquals(channel.json_body["errcode"], Codes.MISSING_PARAM)
+        self.assertEquals(channel.json_body["completed"], [])
+
+        # Test with non-string (invalid)
+        params["auth"]["token"] = 1234
+        channel = self.make_request(b"POST", self.url, json.dumps(params))
+        self.assertEquals(channel.result["code"], b"401", channel.result)
+        self.assertEquals(channel.json_body["errcode"], Codes.INVALID_PARAM)
+        self.assertEquals(channel.json_body["completed"], [])
+
+        # Test with unknown token (invalid)
+        params["auth"]["token"] = "1234"
+        channel = self.make_request(b"POST", self.url, json.dumps(params))
+        self.assertEquals(channel.result["code"], b"401", channel.result)
+        self.assertEquals(channel.json_body["errcode"], Codes.UNAUTHORIZED)
+        self.assertEquals(channel.json_body["completed"], [])
+
+    @override_config({"registration_requires_token": True})
+    def test_POST_registration_token_limit_uses(self):
+        token = "abcd"
+        store = self.hs.get_datastore()
+        # Create token that can be used once
+        self.get_success(
+            store.db_pool.simple_insert(
+                "registration_tokens",
+                {
+                    "token": token,
+                    "uses_allowed": 1,
+                    "pending": 0,
+                    "completed": 0,
+                    "expiry_time": None,
+                },
+            )
+        )
+        params1 = {"username": "bert", "password": "monkey"}
+        params2 = {"username": "ernie", "password": "monkey"}
+        # Do 2 requests without auth to get two session IDs
+        channel1 = self.make_request(b"POST", self.url, json.dumps(params1))
+        session1 = channel1.json_body["session"]
+        channel2 = self.make_request(b"POST", self.url, json.dumps(params2))
+        session2 = channel2.json_body["session"]
+
+        # Use token with session1 and check `pending` is 1
+        params1["auth"] = {
+            "type": LoginType.REGISTRATION_TOKEN,
+            "token": token,
+            "session": session1,
+        }
+        self.make_request(b"POST", self.url, json.dumps(params1))
+        # Repeat request to make sure pending isn't increased again
+        self.make_request(b"POST", self.url, json.dumps(params1))
+        pending = self.get_success(
+            store.db_pool.simple_select_one_onecol(
+                "registration_tokens",
+                keyvalues={"token": token},
+                retcol="pending",
+            )
+        )
+        self.assertEquals(pending, 1)
+
+        # Check auth fails when using token with session2
+        params2["auth"] = {
+            "type": LoginType.REGISTRATION_TOKEN,
+            "token": token,
+            "session": session2,
+        }
+        channel = self.make_request(b"POST", self.url, json.dumps(params2))
+        self.assertEquals(channel.result["code"], b"401", channel.result)
+        self.assertEquals(channel.json_body["errcode"], Codes.UNAUTHORIZED)
+        self.assertEquals(channel.json_body["completed"], [])
+
+        # Complete registration with session1
+        params1["auth"]["type"] = LoginType.DUMMY
+        self.make_request(b"POST", self.url, json.dumps(params1))
+        # Check pending=0 and completed=1
+        res = self.get_success(
+            store.db_pool.simple_select_one(
+                "registration_tokens",
+                keyvalues={"token": token},
+                retcols=["pending", "completed"],
+            )
+        )
+        self.assertEquals(res["pending"], 0)
+        self.assertEquals(res["completed"], 1)
+
+        # Check auth still fails when using token with session2
+        channel = self.make_request(b"POST", self.url, json.dumps(params2))
+        self.assertEquals(channel.result["code"], b"401", channel.result)
+        self.assertEquals(channel.json_body["errcode"], Codes.UNAUTHORIZED)
+        self.assertEquals(channel.json_body["completed"], [])
+
+    @override_config({"registration_requires_token": True})
+    def test_POST_registration_token_expiry(self):
+        token = "abcd"
+        now = self.hs.get_clock().time_msec()
+        store = self.hs.get_datastore()
+        # Create token that expired yesterday
+        self.get_success(
+            store.db_pool.simple_insert(
+                "registration_tokens",
+                {
+                    "token": token,
+                    "uses_allowed": None,
+                    "pending": 0,
+                    "completed": 0,
+                    "expiry_time": now - 24 * 60 * 60 * 1000,
+                },
+            )
+        )
+        params = {"username": "kermit", "password": "monkey"}
+        # Request without auth to get session
+        channel = self.make_request(b"POST", self.url, json.dumps(params))
+        session = channel.json_body["session"]
+
+        # Check authentication fails with expired token
+        params["auth"] = {
+            "type": LoginType.REGISTRATION_TOKEN,
+            "token": token,
+            "session": session,
+        }
+        channel = self.make_request(b"POST", self.url, json.dumps(params))
+        self.assertEquals(channel.result["code"], b"401", channel.result)
+        self.assertEquals(channel.json_body["errcode"], Codes.UNAUTHORIZED)
+        self.assertEquals(channel.json_body["completed"], [])
+
+        # Update token so it expires tomorrow
+        self.get_success(
+            store.db_pool.simple_update_one(
+                "registration_tokens",
+                keyvalues={"token": token},
+                updatevalues={"expiry_time": now + 24 * 60 * 60 * 1000},
+            )
+        )
+
+        # Check authentication succeeds
+        channel = self.make_request(b"POST", self.url, json.dumps(params))
+        completed = channel.json_body["completed"]
+        self.assertCountEqual([LoginType.REGISTRATION_TOKEN], completed)
+
+    @override_config({"registration_requires_token": True})
+    def test_POST_registration_token_session_expiry(self):
+        """Test `pending` is decremented when an uncompleted session expires."""
+        token = "abcd"
+        store = self.hs.get_datastore()
+        self.get_success(
+            store.db_pool.simple_insert(
+                "registration_tokens",
+                {
+                    "token": token,
+                    "uses_allowed": None,
+                    "pending": 0,
+                    "completed": 0,
+                    "expiry_time": None,
+                },
+            )
+        )
+
+        # Do 2 requests without auth to get two session IDs
+        params1 = {"username": "bert", "password": "monkey"}
+        params2 = {"username": "ernie", "password": "monkey"}
+        channel1 = self.make_request(b"POST", self.url, json.dumps(params1))
+        session1 = channel1.json_body["session"]
+        channel2 = self.make_request(b"POST", self.url, json.dumps(params2))
+        session2 = channel2.json_body["session"]
+
+        # Use token with both sessions
+        params1["auth"] = {
+            "type": LoginType.REGISTRATION_TOKEN,
+            "token": token,
+            "session": session1,
+        }
+        self.make_request(b"POST", self.url, json.dumps(params1))
+
+        params2["auth"] = {
+            "type": LoginType.REGISTRATION_TOKEN,
+            "token": token,
+            "session": session2,
+        }
+        self.make_request(b"POST", self.url, json.dumps(params2))
+
+        # Complete registration with session1
+        params1["auth"]["type"] = LoginType.DUMMY
+        self.make_request(b"POST", self.url, json.dumps(params1))
+
+        # Check `result` of registration token stage for session1 is `True`
+        result1 = self.get_success(
+            store.db_pool.simple_select_one_onecol(
+                "ui_auth_sessions_credentials",
+                keyvalues={
+                    "session_id": session1,
+                    "stage_type": LoginType.REGISTRATION_TOKEN,
+                },
+                retcol="result",
+            )
+        )
+        self.assertTrue(db_to_json(result1))
+
+        # Check `result` for session2 is the token used
+        result2 = self.get_success(
+            store.db_pool.simple_select_one_onecol(
+                "ui_auth_sessions_credentials",
+                keyvalues={
+                    "session_id": session2,
+                    "stage_type": LoginType.REGISTRATION_TOKEN,
+                },
+                retcol="result",
+            )
+        )
+        self.assertEquals(db_to_json(result2), token)
+
+        # Delete both sessions (mimics expiry)
+        self.get_success(
+            store.delete_old_ui_auth_sessions(self.hs.get_clock().time_msec())
+        )
+
+        # Check pending is now 0
+        pending = self.get_success(
+            store.db_pool.simple_select_one_onecol(
+                "registration_tokens",
+                keyvalues={"token": token},
+                retcol="pending",
+            )
+        )
+        self.assertEquals(pending, 0)
+
+    @override_config({"registration_requires_token": True})
+    def test_POST_registration_token_session_expiry_deleted_token(self):
+        """Test session expiry doesn't break when the token is deleted.
+
+        1. Start but don't complete UIA with a registration token
+        2. Delete the token from the database
+        3. Expire the session
+        """
+        token = "abcd"
+        store = self.hs.get_datastore()
+        self.get_success(
+            store.db_pool.simple_insert(
+                "registration_tokens",
+                {
+                    "token": token,
+                    "uses_allowed": None,
+                    "pending": 0,
+                    "completed": 0,
+                    "expiry_time": None,
+                },
+            )
+        )
+
+        # Do request without auth to get a session ID
+        params = {"username": "kermit", "password": "monkey"}
+        channel = self.make_request(b"POST", self.url, json.dumps(params))
+        session = channel.json_body["session"]
+
+        # Use token
+        params["auth"] = {
+            "type": LoginType.REGISTRATION_TOKEN,
+            "token": token,
+            "session": session,
+        }
+        self.make_request(b"POST", self.url, json.dumps(params))
+
+        # Delete token
+        self.get_success(
+            store.db_pool.simple_delete_one(
+                "registration_tokens",
+                keyvalues={"token": token},
+            )
+        )
+
+        # Delete session (mimics expiry)
+        self.get_success(
+            store.delete_old_ui_auth_sessions(self.hs.get_clock().time_msec())
+        )
+
     def test_advertised_flows(self):
         channel = self.make_request(b"POST", self.url, b"{}")
         self.assertEquals(channel.result["code"], b"401", channel.result)
@@ -744,3 +1110,71 @@ class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase):
 
         self.assertGreaterEqual(res, now_ms + self.validity_period - self.max_delta)
         self.assertLessEqual(res, now_ms + self.validity_period)
+
+
+class RegistrationTokenValidityRestServletTestCase(unittest.HomeserverTestCase):
+    servlets = [register.register_servlets]
+    url = "/_matrix/client/unstable/org.matrix.msc3231/register/org.matrix.msc3231.login.registration_token/validity"
+
+    def default_config(self):
+        config = super().default_config()
+        config["registration_requires_token"] = True
+        return config
+
+    def test_GET_token_valid(self):
+        token = "abcd"
+        store = self.hs.get_datastore()
+        self.get_success(
+            store.db_pool.simple_insert(
+                "registration_tokens",
+                {
+                    "token": token,
+                    "uses_allowed": None,
+                    "pending": 0,
+                    "completed": 0,
+                    "expiry_time": None,
+                },
+            )
+        )
+
+        channel = self.make_request(
+            b"GET",
+            f"{self.url}?token={token}",
+        )
+        self.assertEquals(channel.result["code"], b"200", channel.result)
+        self.assertEquals(channel.json_body["valid"], True)
+
+    def test_GET_token_invalid(self):
+        token = "1234"
+        channel = self.make_request(
+            b"GET",
+            f"{self.url}?token={token}",
+        )
+        self.assertEquals(channel.result["code"], b"200", channel.result)
+        self.assertEquals(channel.json_body["valid"], False)
+
+    @override_config(
+        {"rc_registration_token_validity": {"per_second": 0.1, "burst_count": 5}}
+    )
+    def test_GET_ratelimiting(self):
+        token = "1234"
+
+        for i in range(0, 6):
+            channel = self.make_request(
+                b"GET",
+                f"{self.url}?token={token}",
+            )
+
+            if i == 5:
+                self.assertEquals(channel.result["code"], b"429", channel.result)
+                retry_after_ms = int(channel.json_body["retry_after_ms"])
+            else:
+                self.assertEquals(channel.result["code"], b"200", channel.result)
+
+        self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
+
+        channel = self.make_request(
+            b"GET",
+            f"{self.url}?token={token}",
+        )
+        self.assertEquals(channel.result["code"], b"200", channel.result)
diff --git a/tests/rest/client/v2_alpha/test_relations.py b/tests/rest/client/test_relations.py
index 02b5e9a8d0..02b5e9a8d0 100644
--- a/tests/rest/client/v2_alpha/test_relations.py
+++ b/tests/rest/client/test_relations.py
diff --git a/tests/rest/client/v2_alpha/test_report_event.py b/tests/rest/client/test_report_event.py
index ee6b0b9ebf..ee6b0b9ebf 100644
--- a/tests/rest/client/v2_alpha/test_report_event.py
+++ b/tests/rest/client/test_report_event.py
diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/test_rooms.py
index 0c9cbb9aff..0c9cbb9aff 100644
--- a/tests/rest/client/v1/test_rooms.py
+++ b/tests/rest/client/test_rooms.py
diff --git a/tests/rest/client/v2_alpha/test_sendtodevice.py b/tests/rest/client/test_sendtodevice.py
index 6db7062a8e..6db7062a8e 100644
--- a/tests/rest/client/v2_alpha/test_sendtodevice.py
+++ b/tests/rest/client/test_sendtodevice.py
diff --git a/tests/rest/client/v2_alpha/test_shared_rooms.py b/tests/rest/client/test_shared_rooms.py
index 283eccd53f..283eccd53f 100644
--- a/tests/rest/client/v2_alpha/test_shared_rooms.py
+++ b/tests/rest/client/test_shared_rooms.py
diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/test_sync.py
index 95be369d4b..95be369d4b 100644
--- a/tests/rest/client/v2_alpha/test_sync.py
+++ b/tests/rest/client/test_sync.py
diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/test_typing.py
index b54b004733..b54b004733 100644
--- a/tests/rest/client/v1/test_typing.py
+++ b/tests/rest/client/test_typing.py
diff --git a/tests/rest/client/v2_alpha/test_upgrade_room.py b/tests/rest/client/test_upgrade_room.py
index 72f976d8e2..72f976d8e2 100644
--- a/tests/rest/client/v2_alpha/test_upgrade_room.py
+++ b/tests/rest/client/test_upgrade_room.py
diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/utils.py
index 954ad1a1fd..954ad1a1fd 100644
--- a/tests/rest/client/v1/utils.py
+++ b/tests/rest/client/utils.py
diff --git a/tests/rest/client/v1/__init__.py b/tests/rest/client/v1/__init__.py
deleted file mode 100644
index 5e83dba2ed..0000000000
--- a/tests/rest/client/v1/__init__.py
+++ /dev/null
@@ -1,13 +0,0 @@
-# Copyright 2014-2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
diff --git a/tests/rest/client/v2_alpha/__init__.py b/tests/rest/client/v2_alpha/__init__.py
deleted file mode 100644
index e69de29bb2..0000000000
--- a/tests/rest/client/v2_alpha/__init__.py
+++ /dev/null
diff --git a/tests/test_federation.py b/tests/test_federation.py
index 3785799f46..348fcb72a7 100644
--- a/tests/test_federation.py
+++ b/tests/test_federation.py
@@ -85,11 +85,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
 
         # Send the join, it should return None (which is not an error)
         self.assertEqual(
-            self.get_success(
-                self.handler.on_receive_pdu(
-                    "test.serv", join_event, sent_to_us_directly=True
-                )
-            ),
+            self.get_success(self.handler.on_receive_pdu("test.serv", join_event)),
             None,
         )
 
@@ -135,9 +131,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
 
         with LoggingContext("test-context"):
             failure = self.get_failure(
-                self.handler.on_receive_pdu(
-                    "test.serv", lying_event, sent_to_us_directly=True
-                ),
+                self.handler.on_receive_pdu("test.serv", lying_event),
                 FederationError,
             )
 
diff --git a/tests/unittest.py b/tests/unittest.py
index 3eec9c4d5b..f2c90cc47b 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -252,7 +252,7 @@ class HomeserverTestCase(TestCase):
             reactor=self.reactor,
         )
 
-        from tests.rest.client.v1.utils import RestHelper
+        from tests.rest.client.utils import RestHelper
 
         self.helper = RestHelper(self.hs, self.site, getattr(self, "user_id", None))