summary refs log tree commit diff
path: root/tests/storage/test_registration.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/storage/test_registration.py')
-rw-r--r--tests/storage/test_registration.py135
1 files changed, 127 insertions, 8 deletions
diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py

index a49ac1525e..05ea802008 100644 --- a/tests/storage/test_registration.py +++ b/tests/storage/test_registration.py
@@ -11,15 +11,19 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from twisted.test.proto_helpers import MemoryReactor from synapse.api.constants import UserTypes from synapse.api.errors import ThreepidValidationError +from synapse.server import HomeServer +from synapse.types import JsonDict, UserID +from synapse.util import Clock -from tests.unittest import HomeserverTestCase +from tests.unittest import HomeserverTestCase, override_config class RegistrationStoreTestCase(HomeserverTestCase): - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastores().main self.user_id = "@my-user:test" @@ -27,7 +31,7 @@ class RegistrationStoreTestCase(HomeserverTestCase): self.pwhash = "{xx1}123456789" self.device_id = "akgjhdjklgshg" - def test_register(self): + def test_register(self) -> None: self.get_success(self.store.register_user(self.user_id, self.pwhash)) self.assertEqual( @@ -38,17 +42,32 @@ class RegistrationStoreTestCase(HomeserverTestCase): "admin": 0, "is_guest": 0, "consent_version": None, + "consent_ts": None, "consent_server_notice_sent": None, "appservice_id": None, "creation_ts": 0, "user_type": None, "deactivated": 0, "shadow_banned": 0, + "approved": 1, }, (self.get_success(self.store.get_user_by_id(self.user_id))), ) - def test_add_tokens(self): + def test_consent(self) -> None: + self.get_success(self.store.register_user(self.user_id, self.pwhash)) + before_consent = self.clock.time_msec() + self.reactor.advance(5) + self.get_success(self.store.user_set_consent_version(self.user_id, "1")) + self.reactor.advance(5) + + user = self.get_success(self.store.get_user_by_id(self.user_id)) + assert user + self.assertEqual(user["consent_version"], "1") + self.assertGreater(user["consent_ts"], before_consent) + self.assertLess(user["consent_ts"], self.clock.time_msec()) + + def test_add_tokens(self) -> None: self.get_success(self.store.register_user(self.user_id, self.pwhash)) self.get_success( self.store.add_access_token_to_user( @@ -58,11 +77,12 @@ class RegistrationStoreTestCase(HomeserverTestCase): result = self.get_success(self.store.get_user_by_access_token(self.tokens[1])) + assert result self.assertEqual(result.user_id, self.user_id) self.assertEqual(result.device_id, self.device_id) self.assertIsNotNone(result.token_id) - def test_user_delete_access_tokens(self): + def test_user_delete_access_tokens(self) -> None: # add some tokens self.get_success(self.store.register_user(self.user_id, self.pwhash)) self.get_success( @@ -87,6 +107,7 @@ class RegistrationStoreTestCase(HomeserverTestCase): # check the one not associated with the device was not deleted user = self.get_success(self.store.get_user_by_access_token(self.tokens[0])) + assert user self.assertEqual(self.user_id, user.user_id) # now delete the rest @@ -95,11 +116,11 @@ class RegistrationStoreTestCase(HomeserverTestCase): user = self.get_success(self.store.get_user_by_access_token(self.tokens[0])) self.assertIsNone(user, "access token was not deleted without device_id") - def test_is_support_user(self): + def test_is_support_user(self) -> None: TEST_USER = "@test:test" SUPPORT_USER = "@support:test" - res = self.get_success(self.store.is_support_user(None)) + res = self.get_success(self.store.is_support_user(None)) # type: ignore[arg-type] self.assertFalse(res) self.get_success( self.store.register_user(user_id=TEST_USER, password_hash=None) @@ -115,7 +136,7 @@ class RegistrationStoreTestCase(HomeserverTestCase): res = self.get_success(self.store.is_support_user(SUPPORT_USER)) self.assertTrue(res) - def test_3pid_inhibit_invalid_validation_session_error(self): + def test_3pid_inhibit_invalid_validation_session_error(self) -> None: """Tests that enabling the configuration option to inhibit 3PID errors on /requestToken also inhibits validation errors caused by an unknown session ID. """ @@ -147,3 +168,101 @@ class RegistrationStoreTestCase(HomeserverTestCase): ThreepidValidationError, ) self.assertEqual(e.value.msg, "Validation token not found or has expired", e) + + +class ApprovalRequiredRegistrationTestCase(HomeserverTestCase): + def default_config(self) -> JsonDict: + config = super().default_config() + + # If there's already some config for this feature in the default config, it + # means we're overriding it with @override_config. In this case we don't want + # to do anything more with it. + msc3866_config = config.get("experimental_features", {}).get("msc3866") + if msc3866_config is not None: + return config + + # Require approval for all new accounts. + config["experimental_features"] = { + "msc3866": { + "enabled": True, + "require_approval_for_new_accounts": True, + } + } + return config + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.store = hs.get_datastores().main + self.user_id = "@my-user:test" + self.pwhash = "{xx1}123456789" + + @override_config( + { + "experimental_features": { + "msc3866": { + "enabled": True, + "require_approval_for_new_accounts": False, + } + } + } + ) + def test_approval_not_required(self) -> None: + """Tests that if we don't require approval for new accounts, newly created + accounts are automatically marked as approved. + """ + self.get_success(self.store.register_user(self.user_id, self.pwhash)) + + user = self.get_success(self.store.get_user_by_id(self.user_id)) + assert user is not None + self.assertTrue(user["approved"]) + + approved = self.get_success(self.store.is_user_approved(self.user_id)) + self.assertTrue(approved) + + def test_approval_required(self) -> None: + """Tests that if we require approval for new accounts, newly created accounts + are not automatically marked as approved. + """ + self.get_success(self.store.register_user(self.user_id, self.pwhash)) + + user = self.get_success(self.store.get_user_by_id(self.user_id)) + assert user is not None + self.assertFalse(user["approved"]) + + approved = self.get_success(self.store.is_user_approved(self.user_id)) + self.assertFalse(approved) + + def test_override(self) -> None: + """Tests that if we require approval for new accounts, but we explicitly say the + new user should be considered approved, they're marked as approved. + """ + self.get_success( + self.store.register_user( + self.user_id, + self.pwhash, + approved=True, + ) + ) + + user = self.get_success(self.store.get_user_by_id(self.user_id)) + self.assertIsNotNone(user) + assert user is not None + self.assertEqual(user["approved"], 1) + + approved = self.get_success(self.store.is_user_approved(self.user_id)) + self.assertTrue(approved) + + def test_approve_user(self) -> None: + """Tests that approving the user updates their approval status.""" + self.get_success(self.store.register_user(self.user_id, self.pwhash)) + + approved = self.get_success(self.store.is_user_approved(self.user_id)) + self.assertFalse(approved) + + self.get_success( + self.store.update_user_approval_status( + UserID.from_string(self.user_id), True + ) + ) + + approved = self.get_success(self.store.is_user_approved(self.user_id)) + self.assertTrue(approved)