diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index 1c128e81f5..a802e1a406 100644
--- a/tests/rest/client/v2_alpha/test_register.py
+++ b/tests/rest/client/v2_alpha/test_register.py
@@ -174,3 +174,73 @@ class RegisterRestServletTestCase(unittest.TestCase):
self.assertEquals(channel.result["code"], b"403", channel.result)
self.assertEquals(channel.json_body["error"], "Guest access is disabled")
+
+ def test_POST_terms_auth(self):
+ self.hs.config.block_events_without_consent_error = True
+ self.hs.config.public_baseurl = "https://example.org"
+ self.hs.config.user_consent_version = "1.0"
+
+ # Do a UI auth request
+ reqest, channel = make_request(b"POST", self.url, b"{}")
+ render(request, self.resource, self.clock)
+
+ self.assertEquals(channel.result["code"], b"401", channel.result)
+
+ self.assertIsInstance(channel.json_body["session"], str)
+
+ self.assertIsInstance(channel.json_body["flows"], list)
+ for flow in channel.json_body["flows"]:
+ self.assertIsInstance(flow["stages"], list)
+ self.assertTrue(len(flow["stages"]) > 0)
+ self.assertEquals(flow["stages"][-1], "m.login.terms")
+
+ expected_params = {
+ "m.login.terms": {
+ "policies": {
+ "privacy_policy": {
+ "en": {
+ "name": "Privacy Policy",
+ "url": "https://example.org/_matrix/consent",
+ },
+ "version": "1.0"
+ },
+ },
+ },
+ }
+ self.assertIsInstance(channel.json_body["params"], dict)
+ self.assertDictContainsSubset(channel.json_body["params"], expected_params)
+
+ # Completing the stage should result in the stage being completed
+
+ user_id = "@kermit:muppet"
+ token = "kermits_access_token"
+ device_id = "frogfone"
+ request_data = json.dumps(
+ {
+ "username": "kermit",
+ "password": "monkey",
+ "device_id": device_id,
+ "session": channel.json_body["session"],
+ }
+ )
+ self.registration_handler.check_username = Mock(return_value=True)
+ self.auth_result = (None, {"username": "kermit", "password": "monkey"}, None)
+ self.registration_handler.register = Mock(return_value=(user_id, None))
+ self.auth_handler.get_access_token_for_user_id = Mock(return_value=token)
+ self.device_handler.check_device_registered = Mock(return_value=device_id)
+
+
+ request, channel = make_request(b"POST", self.url, request_data)
+ render(request, self.resource, self.clock)
+
+ det_data = {
+ "user_id": user_id,
+ "access_token": token,
+ "home_server": self.hs.hostname,
+ "device_id": device_id,
+ }
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+ self.assertDictContainsSubset(det_data, channel.json_body)
+ self.auth_handler.get_login_tuple_for_user_id(
+ user_id, device_id=device_id, initial_device_display_name=None
+ )
|