diff --git a/changelog.d/6106.misc b/changelog.d/6106.misc
new file mode 100644
index 0000000000..d732091779
--- /dev/null
+++ b/changelog.d/6106.misc
@@ -0,0 +1 @@
+Refactor code for calculating registration flows.
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index 135a70808f..e3f3d9126f 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -16,6 +16,7 @@
import hmac
import logging
+from typing import List, Union
from six import string_types
@@ -31,8 +32,11 @@ from synapse.api.errors import (
ThreepidValidationError,
UnrecognizedRequestError,
)
+from synapse.config.captcha import CaptchaConfig
+from synapse.config.consent_config import ConsentConfig
from synapse.config.emailconfig import ThreepidBehaviour
from synapse.config.ratelimiting import FederationRateLimitConfig
+from synapse.config.registration import RegistrationConfig
from synapse.config.server import is_threepid_reserved
from synapse.http.server import finish_request
from synapse.http.servlet import (
@@ -371,6 +375,8 @@ class RegisterRestServlet(RestServlet):
self.ratelimiter = hs.get_registration_ratelimiter()
self.clock = hs.get_clock()
+ self._registration_flows = _calculate_registration_flows(hs.config)
+
@interactive_auth_handler
@defer.inlineCallbacks
def on_POST(self, request):
@@ -491,69 +497,8 @@ class RegisterRestServlet(RestServlet):
assigned_user_id=registered_user_id,
)
- # FIXME: need a better error than "no auth flow found" for scenarios
- # where we required 3PID for registration but the user didn't give one
- require_email = "email" in self.hs.config.registrations_require_3pid
- require_msisdn = "msisdn" in self.hs.config.registrations_require_3pid
-
- show_msisdn = True
- if self.hs.config.disable_msisdn_registration:
- show_msisdn = False
- require_msisdn = False
-
- flows = []
- if self.hs.config.enable_registration_captcha:
- # only support 3PIDless registration if no 3PIDs are required
- if not require_email and not require_msisdn:
- # Also add a dummy flow here, otherwise if a client completes
- # recaptcha first we'll assume they were going for this flow
- # and complete the request, when they could have been trying to
- # complete one of the flows with email/msisdn auth.
- flows.extend([[LoginType.RECAPTCHA, LoginType.DUMMY]])
- # only support the email-only flow if we don't require MSISDN 3PIDs
- if not require_msisdn:
- flows.extend([[LoginType.RECAPTCHA, LoginType.EMAIL_IDENTITY]])
-
- if show_msisdn:
- # only support the MSISDN-only flow if we don't require email 3PIDs
- if not require_email:
- flows.extend([[LoginType.RECAPTCHA, LoginType.MSISDN]])
- # always let users provide both MSISDN & email
- flows.extend(
- [[LoginType.RECAPTCHA, LoginType.MSISDN, LoginType.EMAIL_IDENTITY]]
- )
- else:
- # only support 3PIDless registration if no 3PIDs are required
- if not require_email and not require_msisdn:
- flows.extend([[LoginType.DUMMY]])
- # only support the email-only flow if we don't require MSISDN 3PIDs
- if not require_msisdn:
- flows.extend([[LoginType.EMAIL_IDENTITY]])
-
- if show_msisdn:
- # only support the MSISDN-only flow if we don't require email 3PIDs
- if not require_email or require_msisdn:
- flows.extend([[LoginType.MSISDN]])
- # always let users provide both MSISDN & email
- flows.extend([[LoginType.MSISDN, LoginType.EMAIL_IDENTITY]])
-
- # Append m.login.terms to all flows if we're requiring consent
- if self.hs.config.user_consent_at_registration:
- new_flows = []
- for flow in flows:
- inserted = False
- # m.login.terms should go near the end but before msisdn or email auth
- for i, stage in enumerate(flow):
- if stage == LoginType.EMAIL_IDENTITY or stage == LoginType.MSISDN:
- flow.insert(i, LoginType.TERMS)
- inserted = True
- break
- if not inserted:
- flow.append(LoginType.TERMS)
- flows.extend(new_flows)
-
auth_result, params, session_id = yield self.auth_handler.check_auth(
- flows, body, self.hs.get_ip_from_request(request)
+ self._registration_flows, body, self.hs.get_ip_from_request(request)
)
# Check that we're not trying to register a denied 3pid.
@@ -716,6 +661,61 @@ class RegisterRestServlet(RestServlet):
)
+def _calculate_registration_flows(
+ # technically `config` has to provide *all* of these interfaces, not just one
+ config: Union[RegistrationConfig, ConsentConfig, CaptchaConfig],
+) -> List[List[str]]:
+ """Get a suitable flows list for registration
+
+ Args:
+ config: server configuration
+
+ Returns: a list of supported flows
+ """
+ # FIXME: need a better error than "no auth flow found" for scenarios
+ # where we required 3PID for registration but the user didn't give one
+ require_email = "email" in config.registrations_require_3pid
+ require_msisdn = "msisdn" in config.registrations_require_3pid
+
+ show_msisdn = True
+ if config.disable_msisdn_registration:
+ show_msisdn = False
+ require_msisdn = False
+
+ flows = []
+
+ # only support 3PIDless registration if no 3PIDs are required
+ if not require_email and not require_msisdn:
+ # Add a dummy step here, otherwise if a client completes
+ # recaptcha first we'll assume they were going for this flow
+ # and complete the request, when they could have been trying to
+ # complete one of the flows with email/msisdn auth.
+ flows.append([LoginType.DUMMY])
+
+ # only support the email-only flow if we don't require MSISDN 3PIDs
+ if not require_msisdn:
+ flows.append([LoginType.EMAIL_IDENTITY])
+
+ # only support the MSISDN-only flow if we don't require email 3PIDs
+ if show_msisdn and not require_email:
+ flows.append([LoginType.MSISDN])
+
+ if show_msisdn:
+ flows.append([LoginType.MSISDN, LoginType.EMAIL_IDENTITY])
+
+ # Prepend m.login.terms to all flows if we're requiring consent
+ if config.user_consent_at_registration:
+ for flow in flows:
+ flow.insert(0, LoginType.TERMS)
+
+ # Prepend recaptcha to all flows if we're requiring captcha
+ if config.enable_registration_captcha:
+ for flow in flows:
+ flow.insert(0, LoginType.RECAPTCHA)
+
+ return flows
+
+
def register_servlets(hs, http_server):
EmailRegisterRequestTokenRestServlet(hs).register(http_server)
MsisdnRegisterRequestTokenRestServlet(hs).register(http_server)
diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index ab4d7d70d0..bc2dc47973 100644
--- a/tests/rest/client/v2_alpha/test_register.py
+++ b/tests/rest/client/v2_alpha/test_register.py
@@ -34,19 +34,12 @@ from tests import unittest
class RegisterRestServletTestCase(unittest.HomeserverTestCase):
servlets = [register.register_servlets]
+ url = b"/_matrix/client/r0/register"
- def make_homeserver(self, reactor, clock):
-
- self.url = b"/_matrix/client/r0/register"
-
- self.hs = self.setup_test_homeserver()
- self.hs.config.enable_registration = True
- self.hs.config.registrations_require_3pid = []
- self.hs.config.auto_join_rooms = []
- self.hs.config.enable_registration_captcha = False
- self.hs.config.allow_guest_access = True
-
- return self.hs
+ def default_config(self, name="test"):
+ config = super().default_config(name)
+ config["allow_guest_access"] = True
+ return config
def test_POST_appservice_registration_valid(self):
user_id = "@as_user_kermit:test"
@@ -199,6 +192,68 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.assertEquals(channel.result["code"], b"200", channel.result)
+ def test_advertised_flows(self):
+ request, channel = self.make_request(b"POST", self.url, b"{}")
+ self.render(request)
+ self.assertEquals(channel.result["code"], b"401", channel.result)
+ flows = channel.json_body["flows"]
+
+ # with the stock config, we expect all four combinations of 3pid
+ self.assertCountEqual(
+ [
+ ["m.login.dummy"],
+ ["m.login.email.identity"],
+ ["m.login.msisdn"],
+ ["m.login.msisdn", "m.login.email.identity"],
+ ],
+ (f["stages"] for f in flows),
+ )
+
+ @unittest.override_config(
+ {
+ "enable_registration_captcha": True,
+ "user_consent": {
+ "version": "1",
+ "template_dir": "/",
+ "require_at_registration": True,
+ },
+ }
+ )
+ def test_advertised_flows_captcha_and_terms(self):
+ request, channel = self.make_request(b"POST", self.url, b"{}")
+ self.render(request)
+ self.assertEquals(channel.result["code"], b"401", channel.result)
+ flows = channel.json_body["flows"]
+
+ self.assertCountEqual(
+ [
+ ["m.login.recaptcha", "m.login.terms", "m.login.dummy"],
+ ["m.login.recaptcha", "m.login.terms", "m.login.email.identity"],
+ ["m.login.recaptcha", "m.login.terms", "m.login.msisdn"],
+ [
+ "m.login.recaptcha",
+ "m.login.terms",
+ "m.login.msisdn",
+ "m.login.email.identity",
+ ],
+ ],
+ (f["stages"] for f in flows),
+ )
+
+ @unittest.override_config(
+ {"registrations_require_3pid": ["email"], "disable_msisdn_registration": True}
+ )
+ def test_advertised_flows_no_msisdn_email_required(self):
+ request, channel = self.make_request(b"POST", self.url, b"{}")
+ self.render(request)
+ self.assertEquals(channel.result["code"], b"401", channel.result)
+ flows = channel.json_body["flows"]
+
+ # with the stock config, we expect all four combinations of 3pid
+ self.assertCountEqual(
+ [["m.login.email.identity"]], (f["stages"] for f in flows)
+ )
+
class AccountValidityTestCase(unittest.HomeserverTestCase):
diff --git a/tests/test_terms_auth.py b/tests/test_terms_auth.py
index 52739fbabc..5ec5d2b358 100644
--- a/tests/test_terms_auth.py
+++ b/tests/test_terms_auth.py
@@ -28,6 +28,21 @@ from tests import unittest
class TermsTestCase(unittest.HomeserverTestCase):
servlets = [register_servlets]
+ def default_config(self, name="test"):
+ config = super().default_config(name)
+ config.update(
+ {
+ "public_baseurl": "https://example.org/",
+ "user_consent": {
+ "version": "1.0",
+ "policy_name": "My Cool Privacy Policy",
+ "template_dir": "/",
+ "require_at_registration": True,
+ },
+ }
+ )
+ return config
+
def prepare(self, reactor, clock, hs):
self.clock = MemoryReactorClock()
self.hs_clock = Clock(self.clock)
@@ -35,17 +50,8 @@ class TermsTestCase(unittest.HomeserverTestCase):
self.registration_handler = Mock()
self.auth_handler = Mock()
self.device_handler = Mock()
- hs.config.enable_registration = True
- hs.config.registrations_require_3pid = []
- hs.config.auto_join_rooms = []
- hs.config.enable_registration_captcha = False
def test_ui_auth(self):
- self.hs.config.user_consent_at_registration = True
- self.hs.config.user_consent_policy_name = "My Cool Privacy Policy"
- self.hs.config.public_baseurl = "https://example.org/"
- self.hs.config.user_consent_version = "1.0"
-
# Do a UI auth request
request, channel = self.make_request(b"POST", self.url, b"{}")
self.render(request)
|