diff options
Diffstat (limited to 'tests/rest/client')
-rw-r--r-- | tests/rest/client/test_auth.py | 33 | ||||
-rw-r--r-- | tests/rest/client/test_login.py | 41 | ||||
-rw-r--r-- | tests/rest/client/test_register.py | 32 | ||||
-rw-r--r-- | tests/rest/client/utils.py | 12 |
4 files changed, 113 insertions, 5 deletions
diff --git a/tests/rest/client/test_auth.py b/tests/rest/client/test_auth.py index 05355c7fb6..090cef5216 100644 --- a/tests/rest/client/test_auth.py +++ b/tests/rest/client/test_auth.py @@ -20,7 +20,8 @@ from twisted.test.proto_helpers import MemoryReactor from twisted.web.resource import Resource import synapse.rest.admin -from synapse.api.constants import LoginType +from synapse.api.constants import ApprovalNoticeMedium, LoginType +from synapse.api.errors import Codes from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker from synapse.rest.client import account, auth, devices, login, logout, register from synapse.rest.synapse.client import build_synapse_client_resource_tree @@ -567,6 +568,36 @@ class UIAuthTests(unittest.HomeserverTestCase): body={"auth": {"session": session_id}}, ) + @skip_unless(HAS_OIDC, "requires OIDC") + @override_config( + { + "oidc_config": TEST_OIDC_CONFIG, + "experimental_features": { + "msc3866": { + "enabled": True, + "require_approval_for_new_accounts": True, + } + }, + } + ) + def test_sso_not_approved(self) -> None: + """Tests that if we register a user via SSO while requiring approval for new + accounts, we still raise the correct error before logging the user in. + """ + login_resp = self.helper.login_via_oidc("username", expected_status=403) + + self.assertEqual(login_resp["errcode"], Codes.USER_AWAITING_APPROVAL) + self.assertEqual( + ApprovalNoticeMedium.NONE, login_resp["approval_notice_medium"] + ) + + # Check that we didn't register a device for the user during the login attempt. + devices = self.get_success( + self.hs.get_datastores().main.get_devices_by_user("@username:test") + ) + + self.assertEqual(len(devices), 0) + class RefreshAuthTests(unittest.HomeserverTestCase): servlets = [ diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py index e2a4d98275..e801ba8c8b 100644 --- a/tests/rest/client/test_login.py +++ b/tests/rest/client/test_login.py @@ -23,6 +23,8 @@ from twisted.test.proto_helpers import MemoryReactor from twisted.web.resource import Resource import synapse.rest.admin +from synapse.api.constants import ApprovalNoticeMedium, LoginType +from synapse.api.errors import Codes from synapse.appservice import ApplicationService from synapse.rest.client import devices, login, logout, register from synapse.rest.client.account import WhoamiRestServlet @@ -94,6 +96,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): logout.register_servlets, devices.register_servlets, lambda hs, http_server: WhoamiRestServlet(hs).register(http_server), + register.register_servlets, ] def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: @@ -406,6 +409,44 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.code, 400) self.assertEqual(channel.json_body["errcode"], "M_INVALID_PARAM") + @override_config( + { + "experimental_features": { + "msc3866": { + "enabled": True, + "require_approval_for_new_accounts": True, + } + } + } + ) + def test_require_approval(self) -> None: + channel = self.make_request( + "POST", + "register", + { + "username": "kermit", + "password": "monkey", + "auth": {"type": LoginType.DUMMY}, + }, + ) + self.assertEqual(403, channel.code, channel.result) + self.assertEqual(Codes.USER_AWAITING_APPROVAL, channel.json_body["errcode"]) + self.assertEqual( + ApprovalNoticeMedium.NONE, channel.json_body["approval_notice_medium"] + ) + + params = { + "type": LoginType.PASSWORD, + "identifier": {"type": "m.id.user", "user": "kermit"}, + "password": "monkey", + } + channel = self.make_request("POST", LOGIN_URL, params) + self.assertEqual(403, channel.code, channel.result) + self.assertEqual(Codes.USER_AWAITING_APPROVAL, channel.json_body["errcode"]) + self.assertEqual( + ApprovalNoticeMedium.NONE, channel.json_body["approval_notice_medium"] + ) + @skip_unless(has_saml2 and HAS_OIDC, "Requires SAML2 and OIDC") class MultiSSOTestCase(unittest.HomeserverTestCase): diff --git a/tests/rest/client/test_register.py b/tests/rest/client/test_register.py index b781875d52..11cf3939d8 100644 --- a/tests/rest/client/test_register.py +++ b/tests/rest/client/test_register.py @@ -22,7 +22,11 @@ import pkg_resources from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin -from synapse.api.constants import APP_SERVICE_REGISTRATION_TYPE, LoginType +from synapse.api.constants import ( + APP_SERVICE_REGISTRATION_TYPE, + ApprovalNoticeMedium, + LoginType, +) from synapse.api.errors import Codes from synapse.appservice import ApplicationService from synapse.rest.client import account, account_validity, login, logout, register, sync @@ -765,6 +769,32 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.code, 400, channel.json_body) self.assertEqual(channel.json_body["errcode"], Codes.USER_IN_USE) + @override_config( + { + "experimental_features": { + "msc3866": { + "enabled": True, + "require_approval_for_new_accounts": True, + } + } + } + ) + def test_require_approval(self) -> None: + channel = self.make_request( + "POST", + "register", + { + "username": "kermit", + "password": "monkey", + "auth": {"type": LoginType.DUMMY}, + }, + ) + self.assertEqual(403, channel.code, channel.result) + self.assertEqual(Codes.USER_AWAITING_APPROVAL, channel.json_body["errcode"]) + self.assertEqual( + ApprovalNoticeMedium.NONE, channel.json_body["approval_notice_medium"] + ) + class AccountValidityTestCase(unittest.HomeserverTestCase): diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py index dd26145bf8..c249a42bb6 100644 --- a/tests/rest/client/utils.py +++ b/tests/rest/client/utils.py @@ -543,8 +543,12 @@ class RestHelper: return channel.json_body - def login_via_oidc(self, remote_user_id: str) -> JsonDict: - """Log in (as a new user) via OIDC + def login_via_oidc( + self, + remote_user_id: str, + expected_status: int = 200, + ) -> JsonDict: + """Log in via OIDC Returns the result of the final token login. @@ -578,7 +582,9 @@ class RestHelper: "/login", content={"type": "m.login.token", "token": login_token}, ) - assert channel.code == HTTPStatus.OK + assert ( + channel.code == expected_status + ), f"unexpected status in response: {channel.code}" return channel.json_body def auth_via_oidc( |