summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/rest/client/v2_alpha/test_register.py79
-rw-r--r--tests/test_terms_auth.py24
2 files changed, 82 insertions, 21 deletions
diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index ab4d7d70d0..bc2dc47973 100644
--- a/tests/rest/client/v2_alpha/test_register.py
+++ b/tests/rest/client/v2_alpha/test_register.py
@@ -34,19 +34,12 @@ from tests import unittest
 class RegisterRestServletTestCase(unittest.HomeserverTestCase):
 
     servlets = [register.register_servlets]
+    url = b"/_matrix/client/r0/register"
 
-    def make_homeserver(self, reactor, clock):
-
-        self.url = b"/_matrix/client/r0/register"
-
-        self.hs = self.setup_test_homeserver()
-        self.hs.config.enable_registration = True
-        self.hs.config.registrations_require_3pid = []
-        self.hs.config.auto_join_rooms = []
-        self.hs.config.enable_registration_captcha = False
-        self.hs.config.allow_guest_access = True
-
-        return self.hs
+    def default_config(self, name="test"):
+        config = super().default_config(name)
+        config["allow_guest_access"] = True
+        return config
 
     def test_POST_appservice_registration_valid(self):
         user_id = "@as_user_kermit:test"
@@ -199,6 +192,68 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
 
         self.assertEquals(channel.result["code"], b"200", channel.result)
 
+    def test_advertised_flows(self):
+        request, channel = self.make_request(b"POST", self.url, b"{}")
+        self.render(request)
+        self.assertEquals(channel.result["code"], b"401", channel.result)
+        flows = channel.json_body["flows"]
+
+        # with the stock config, we expect all four combinations of 3pid
+        self.assertCountEqual(
+            [
+                ["m.login.dummy"],
+                ["m.login.email.identity"],
+                ["m.login.msisdn"],
+                ["m.login.msisdn", "m.login.email.identity"],
+            ],
+            (f["stages"] for f in flows),
+        )
+
+    @unittest.override_config(
+        {
+            "enable_registration_captcha": True,
+            "user_consent": {
+                "version": "1",
+                "template_dir": "/",
+                "require_at_registration": True,
+            },
+        }
+    )
+    def test_advertised_flows_captcha_and_terms(self):
+        request, channel = self.make_request(b"POST", self.url, b"{}")
+        self.render(request)
+        self.assertEquals(channel.result["code"], b"401", channel.result)
+        flows = channel.json_body["flows"]
+
+        self.assertCountEqual(
+            [
+                ["m.login.recaptcha", "m.login.terms", "m.login.dummy"],
+                ["m.login.recaptcha", "m.login.terms", "m.login.email.identity"],
+                ["m.login.recaptcha", "m.login.terms", "m.login.msisdn"],
+                [
+                    "m.login.recaptcha",
+                    "m.login.terms",
+                    "m.login.msisdn",
+                    "m.login.email.identity",
+                ],
+            ],
+            (f["stages"] for f in flows),
+        )
+
+    @unittest.override_config(
+        {"registrations_require_3pid": ["email"], "disable_msisdn_registration": True}
+    )
+    def test_advertised_flows_no_msisdn_email_required(self):
+        request, channel = self.make_request(b"POST", self.url, b"{}")
+        self.render(request)
+        self.assertEquals(channel.result["code"], b"401", channel.result)
+        flows = channel.json_body["flows"]
+
+        # with the stock config, we expect all four combinations of 3pid
+        self.assertCountEqual(
+            [["m.login.email.identity"]], (f["stages"] for f in flows)
+        )
+
 
 class AccountValidityTestCase(unittest.HomeserverTestCase):
 
diff --git a/tests/test_terms_auth.py b/tests/test_terms_auth.py
index 52739fbabc..5ec5d2b358 100644
--- a/tests/test_terms_auth.py
+++ b/tests/test_terms_auth.py
@@ -28,6 +28,21 @@ from tests import unittest
 class TermsTestCase(unittest.HomeserverTestCase):
     servlets = [register_servlets]
 
+    def default_config(self, name="test"):
+        config = super().default_config(name)
+        config.update(
+            {
+                "public_baseurl": "https://example.org/",
+                "user_consent": {
+                    "version": "1.0",
+                    "policy_name": "My Cool Privacy Policy",
+                    "template_dir": "/",
+                    "require_at_registration": True,
+                },
+            }
+        )
+        return config
+
     def prepare(self, reactor, clock, hs):
         self.clock = MemoryReactorClock()
         self.hs_clock = Clock(self.clock)
@@ -35,17 +50,8 @@ class TermsTestCase(unittest.HomeserverTestCase):
         self.registration_handler = Mock()
         self.auth_handler = Mock()
         self.device_handler = Mock()
-        hs.config.enable_registration = True
-        hs.config.registrations_require_3pid = []
-        hs.config.auto_join_rooms = []
-        hs.config.enable_registration_captcha = False
 
     def test_ui_auth(self):
-        self.hs.config.user_consent_at_registration = True
-        self.hs.config.user_consent_policy_name = "My Cool Privacy Policy"
-        self.hs.config.public_baseurl = "https://example.org/"
-        self.hs.config.user_consent_version = "1.0"
-
         # Do a UI auth request
         request, channel = self.make_request(b"POST", self.url, b"{}")
         self.render(request)