summary refs log tree commit diff
path: root/tests/rest
diff options
context:
space:
mode:
Diffstat (limited to 'tests/rest')
-rw-r--r--tests/rest/client/v2_alpha/test_auth.py26
-rw-r--r--tests/rest/client/v2_alpha/test_register.py84
2 files changed, 88 insertions, 22 deletions
diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/v2_alpha/test_auth.py
index b9ef46e8fb..b6df1396ad 100644
--- a/tests/rest/client/v2_alpha/test_auth.py
+++ b/tests/rest/client/v2_alpha/test_auth.py
@@ -18,11 +18,22 @@ from twisted.internet.defer import succeed
 
 import synapse.rest.admin
 from synapse.api.constants import LoginType
+from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker
 from synapse.rest.client.v2_alpha import auth, register
 
 from tests import unittest
 
 
+class DummyRecaptchaChecker(UserInteractiveAuthChecker):
+    def __init__(self, hs):
+        super().__init__(hs)
+        self.recaptcha_attempts = []
+
+    def check_auth(self, authdict, clientip):
+        self.recaptcha_attempts.append((authdict, clientip))
+        return succeed(True)
+
+
 class FallbackAuthTests(unittest.HomeserverTestCase):
 
     servlets = [
@@ -44,15 +55,9 @@ class FallbackAuthTests(unittest.HomeserverTestCase):
         return hs
 
     def prepare(self, reactor, clock, hs):
+        self.recaptcha_checker = DummyRecaptchaChecker(hs)
         auth_handler = hs.get_auth_handler()
-
-        self.recaptcha_attempts = []
-
-        def _recaptcha(authdict, clientip):
-            self.recaptcha_attempts.append((authdict, clientip))
-            return succeed(True)
-
-        auth_handler.checkers[LoginType.RECAPTCHA] = _recaptcha
+        auth_handler.checkers[LoginType.RECAPTCHA] = self.recaptcha_checker
 
     @unittest.INFO
     def test_fallback_captcha(self):
@@ -89,8 +94,9 @@ class FallbackAuthTests(unittest.HomeserverTestCase):
         self.assertEqual(request.code, 200)
 
         # The recaptcha handler is called with the response given
-        self.assertEqual(len(self.recaptcha_attempts), 1)
-        self.assertEqual(self.recaptcha_attempts[0][0]["response"], "a")
+        attempts = self.recaptcha_checker.recaptcha_attempts
+        self.assertEqual(len(attempts), 1)
+        self.assertEqual(attempts[0][0]["response"], "a")
 
         # also complete the dummy auth
         request, channel = self.make_request(
diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index ab4d7d70d0..dab87e5edf 100644
--- a/tests/rest/client/v2_alpha/test_register.py
+++ b/tests/rest/client/v2_alpha/test_register.py
@@ -34,19 +34,12 @@ from tests import unittest
 class RegisterRestServletTestCase(unittest.HomeserverTestCase):
 
     servlets = [register.register_servlets]
+    url = b"/_matrix/client/r0/register"
 
-    def make_homeserver(self, reactor, clock):
-
-        self.url = b"/_matrix/client/r0/register"
-
-        self.hs = self.setup_test_homeserver()
-        self.hs.config.enable_registration = True
-        self.hs.config.registrations_require_3pid = []
-        self.hs.config.auto_join_rooms = []
-        self.hs.config.enable_registration_captcha = False
-        self.hs.config.allow_guest_access = True
-
-        return self.hs
+    def default_config(self, name="test"):
+        config = super().default_config(name)
+        config["allow_guest_access"] = True
+        return config
 
     def test_POST_appservice_registration_valid(self):
         user_id = "@as_user_kermit:test"
@@ -199,6 +192,73 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
 
         self.assertEquals(channel.result["code"], b"200", channel.result)
 
+    def test_advertised_flows(self):
+        request, channel = self.make_request(b"POST", self.url, b"{}")
+        self.render(request)
+        self.assertEquals(channel.result["code"], b"401", channel.result)
+        flows = channel.json_body["flows"]
+
+        # with the stock config, we only expect the dummy flow
+        self.assertCountEqual([["m.login.dummy"]], (f["stages"] for f in flows))
+
+    @unittest.override_config(
+        {
+            "enable_registration_captcha": True,
+            "user_consent": {
+                "version": "1",
+                "template_dir": "/",
+                "require_at_registration": True,
+            },
+            "account_threepid_delegates": {
+                "email": "https://id_server",
+                "msisdn": "https://id_server",
+            },
+        }
+    )
+    def test_advertised_flows_captcha_and_terms_and_3pids(self):
+        request, channel = self.make_request(b"POST", self.url, b"{}")
+        self.render(request)
+        self.assertEquals(channel.result["code"], b"401", channel.result)
+        flows = channel.json_body["flows"]
+
+        self.assertCountEqual(
+            [
+                ["m.login.recaptcha", "m.login.terms", "m.login.dummy"],
+                ["m.login.recaptcha", "m.login.terms", "m.login.email.identity"],
+                ["m.login.recaptcha", "m.login.terms", "m.login.msisdn"],
+                [
+                    "m.login.recaptcha",
+                    "m.login.terms",
+                    "m.login.msisdn",
+                    "m.login.email.identity",
+                ],
+            ],
+            (f["stages"] for f in flows),
+        )
+
+    @unittest.override_config(
+        {
+            "public_baseurl": "https://test_server",
+            "registrations_require_3pid": ["email"],
+            "disable_msisdn_registration": True,
+            "email": {
+                "smtp_host": "mail_server",
+                "smtp_port": 2525,
+                "notif_from": "sender@host",
+            },
+        }
+    )
+    def test_advertised_flows_no_msisdn_email_required(self):
+        request, channel = self.make_request(b"POST", self.url, b"{}")
+        self.render(request)
+        self.assertEquals(channel.result["code"], b"401", channel.result)
+        flows = channel.json_body["flows"]
+
+        # with the stock config, we expect all four combinations of 3pid
+        self.assertCountEqual(
+            [["m.login.email.identity"]], (f["stages"] for f in flows)
+        )
+
 
 class AccountValidityTestCase(unittest.HomeserverTestCase):