summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/rest/client/v2_alpha/test_register.py73
1 files changed, 73 insertions, 0 deletions
diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index 1c128e81f5..36eaabbad8 100644
--- a/tests/rest/client/v2_alpha/test_register.py
+++ b/tests/rest/client/v2_alpha/test_register.py
@@ -174,3 +174,76 @@ class RegisterRestServletTestCase(unittest.TestCase):
 
         self.assertEquals(channel.result["code"], b"403", channel.result)
         self.assertEquals(channel.json_body["error"], "Guest access is disabled")
+
+    def test_POST_terms_auth(self):
+        self.hs.config.block_events_without_consent_error = True
+        self.hs.config.public_baseurl = "https://example.org"
+        self.hs.config.user_consent_version = "1.0"
+
+        # Do a UI auth request
+        request, channel = make_request(b"POST", self.url, b"{}")
+        render(request, self.resource, self.clock)
+
+        self.assertEquals(channel.result["code"], b"401", channel.result)
+
+        self.assertIsInstance(channel.json_body["session"], str)
+
+        self.assertIsInstance(channel.json_body["flows"], list)
+        for flow in channel.json_body["flows"]:
+            self.assertIsInstance(flow["stages"], list)
+            self.assertTrue(len(flow["stages"]) > 0)
+            self.assertEquals(flow["stages"][-1], "m.login.terms")
+
+        expected_params = {
+            "m.login.terms": {
+                "policies": {
+                    "privacy_policy": {
+                        "en": {
+                            "name": "Privacy Policy",
+                            "url": "https://example.org/_matrix/consent",
+                        },
+                        "version": "1.0"
+                    },
+                },
+            },
+        }
+        self.assertIsInstance(channel.json_body["params"], dict)
+        self.assertDictContainsSubset(channel.json_body["params"], expected_params)
+
+        # Completing the stage should result in the stage being completed
+
+        user_id = "@kermit:muppet"
+        token = "kermits_access_token"
+        device_id = "frogfone"
+        request_data = json.dumps(
+            {
+                "username": "kermit",
+                "password": "monkey",
+                "device_id": device_id,
+                "auth": {
+                    "session": channel.json_body["session"],
+                    "type": "m.login.terms",
+                },
+            }
+        )
+        self.registration_handler.check_username = Mock(return_value=True)
+        self.auth_result = (None, {"username": "kermit", "password": "monkey"}, None)
+        self.registration_handler.register = Mock(return_value=(user_id, None))
+        self.auth_handler.get_access_token_for_user_id = Mock(return_value=token)
+        self.device_handler.check_device_registered = Mock(return_value=device_id)
+
+
+        request, channel = make_request(b"POST", self.url, request_data)
+        render(request, self.resource, self.clock)
+
+        det_data = {
+            "user_id": user_id,
+            "access_token": token,
+            "home_server": self.hs.hostname,
+            "device_id": device_id,
+        }
+        self.assertEquals(channel.result["code"], b"200", channel.result)
+        self.assertDictContainsSubset(det_data, channel.json_body)
+        self.auth_handler.get_login_tuple_for_user_id(
+            user_id, device_id=device_id, initial_device_display_name=None
+        )