From 8dd2ea65a9566fd0850df0d989f700f61b490ed9 Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Thu, 18 Mar 2021 17:54:08 +0100 Subject: Consistently check whether a password may be set for a user. (#9636) --- changelog.d/9636.bugfix | 1 + synapse/handlers/set_password.py | 2 +- synapse/rest/admin/users.py | 2 +- synapse/storage/databases/main/registration.py | 1 + tests/rest/admin/test_user.py | 173 +++++++++++++++++-------- 5 files changed, 122 insertions(+), 57 deletions(-) create mode 100644 changelog.d/9636.bugfix diff --git a/changelog.d/9636.bugfix b/changelog.d/9636.bugfix new file mode 100644 index 0000000000..fa772ed6fc --- /dev/null +++ b/changelog.d/9636.bugfix @@ -0,0 +1 @@ +Checks if passwords are allowed before setting it for the user. \ No newline at end of file diff --git a/synapse/handlers/set_password.py b/synapse/handlers/set_password.py index 84af2dde7e..04e7c64c94 100644 --- a/synapse/handlers/set_password.py +++ b/synapse/handlers/set_password.py @@ -41,7 +41,7 @@ class SetPasswordHandler(BaseHandler): logout_devices: bool, requester: Optional[Requester] = None, ) -> None: - if not self.hs.config.password_localdb_enabled: + if not self._auth_handler.can_change_password(): raise SynapseError(403, "Password change disabled", errcode=Codes.FORBIDDEN) try: diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index 2c89b62e25..aaa56a7024 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -271,7 +271,7 @@ class UserRestServletV2(RestServlet): elif not deactivate and user["deactivated"]: if ( "password" not in body - and self.hs.config.password_localdb_enabled + and self.auth_handler.can_change_password() ): raise SynapseError( 400, "Must provide a password to re-activate an account." diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index eba66ff352..90a8f664ef 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -1210,6 +1210,7 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): self._invalidate_cache_and_stream( txn, self.get_user_deactivated_status, (user_id,) ) + self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) txn.call_after(self.is_guest.invalidate, (user_id,)) @cached() diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index e58d5cf0db..cf61f284cb 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -1003,12 +1003,23 @@ class UserRestTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): self.store = hs.get_datastore() + self.auth_handler = hs.get_auth_handler() + # create users and get access tokens + # regardless of whether password login or SSO is allowed self.admin_user = self.register_user("admin", "pass", admin=True) - self.admin_user_tok = self.login("admin", "pass") + self.admin_user_tok = self.get_success( + self.auth_handler.get_access_token_for_user_id( + self.admin_user, device_id=None, valid_until_ms=None + ) + ) self.other_user = self.register_user("user", "pass", displayname="User") - self.other_user_token = self.login("user", "pass") + self.other_user_token = self.get_success( + self.auth_handler.get_access_token_for_user_id( + self.other_user, device_id=None, valid_until_ms=None + ) + ) self.url_other_user = "/_synapse/admin/v2/users/%s" % urllib.parse.quote( self.other_user ) @@ -1081,7 +1092,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual("Bob's name", channel.json_body["displayname"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"]) - self.assertEqual(True, channel.json_body["admin"]) + self.assertTrue(channel.json_body["admin"]) self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"]) # Get user @@ -1096,9 +1107,9 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual("Bob's name", channel.json_body["displayname"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"]) - self.assertEqual(True, channel.json_body["admin"]) - self.assertEqual(False, channel.json_body["is_guest"]) - self.assertEqual(False, channel.json_body["deactivated"]) + self.assertTrue(channel.json_body["admin"]) + self.assertFalse(channel.json_body["is_guest"]) + self.assertFalse(channel.json_body["deactivated"]) self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"]) def test_create_user(self): @@ -1130,7 +1141,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual("Bob's name", channel.json_body["displayname"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"]) - self.assertEqual(False, channel.json_body["admin"]) + self.assertFalse(channel.json_body["admin"]) self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"]) # Get user @@ -1145,10 +1156,10 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual("Bob's name", channel.json_body["displayname"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"]) - self.assertEqual(False, channel.json_body["admin"]) - self.assertEqual(False, channel.json_body["is_guest"]) - self.assertEqual(False, channel.json_body["deactivated"]) - self.assertEqual(False, channel.json_body["shadow_banned"]) + self.assertFalse(channel.json_body["admin"]) + self.assertFalse(channel.json_body["is_guest"]) + self.assertFalse(channel.json_body["deactivated"]) + self.assertFalse(channel.json_body["shadow_banned"]) self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"]) @override_config( @@ -1197,7 +1208,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@bob:test", channel.json_body["name"]) - self.assertEqual(False, channel.json_body["admin"]) + self.assertFalse(channel.json_body["admin"]) @override_config( {"limit_usage_by_mau": True, "max_mau_value": 2, "mau_trial_days": 0} @@ -1237,7 +1248,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): # Admin user is not blocked by mau anymore self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@bob:test", channel.json_body["name"]) - self.assertEqual(False, channel.json_body["admin"]) + self.assertFalse(channel.json_body["admin"]) @override_config( { @@ -1429,24 +1440,23 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@user:test", channel.json_body["name"]) - self.assertEqual(False, channel.json_body["deactivated"]) + self.assertFalse(channel.json_body["deactivated"]) self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"]) self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"]) self.assertEqual("User", channel.json_body["displayname"]) # Deactivate user - body = json.dumps({"deactivated": True}) - channel = self.make_request( "PUT", self.url_other_user, access_token=self.admin_user_tok, - content=body.encode(encoding="utf_8"), + content={"deactivated": True}, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@user:test", channel.json_body["name"]) - self.assertEqual(True, channel.json_body["deactivated"]) + self.assertTrue(channel.json_body["deactivated"]) + self.assertIsNone(channel.json_body["password_hash"]) self.assertEqual(0, len(channel.json_body["threepids"])) self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"]) self.assertEqual("User", channel.json_body["displayname"]) @@ -1461,7 +1471,8 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@user:test", channel.json_body["name"]) - self.assertEqual(True, channel.json_body["deactivated"]) + self.assertTrue(channel.json_body["deactivated"]) + self.assertIsNone(channel.json_body["password_hash"]) self.assertEqual(0, len(channel.json_body["threepids"])) self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"]) self.assertEqual("User", channel.json_body["displayname"]) @@ -1478,41 +1489,37 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertTrue(profile["display_name"] == "User") # Deactivate user - body = json.dumps({"deactivated": True}) - channel = self.make_request( "PUT", self.url_other_user, access_token=self.admin_user_tok, - content=body.encode(encoding="utf_8"), + content={"deactivated": True}, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@user:test", channel.json_body["name"]) - self.assertEqual(True, channel.json_body["deactivated"]) + self.assertTrue(channel.json_body["deactivated"]) # is not in user directory profile = self.get_success(self.store.get_user_in_directory(self.other_user)) - self.assertTrue(profile is None) + self.assertIsNone(profile) # Set new displayname user - body = json.dumps({"displayname": "Foobar"}) - channel = self.make_request( "PUT", self.url_other_user, access_token=self.admin_user_tok, - content=body.encode(encoding="utf_8"), + content={"displayname": "Foobar"}, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@user:test", channel.json_body["name"]) - self.assertEqual(True, channel.json_body["deactivated"]) + self.assertTrue(channel.json_body["deactivated"]) self.assertEqual("Foobar", channel.json_body["displayname"]) # is not in user directory profile = self.get_success(self.store.get_user_in_directory(self.other_user)) - self.assertTrue(profile is None) + self.assertIsNone(profile) def test_reactivate_user(self): """ @@ -1520,48 +1527,92 @@ class UserRestTestCase(unittest.HomeserverTestCase): """ # Deactivate the user. + self._deactivate_user("@user:test") + + # Attempt to reactivate the user (without a password). + channel = self.make_request( + "PUT", + self.url_other_user, + access_token=self.admin_user_tok, + content={"deactivated": False}, + ) + self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + + # Reactivate the user. channel = self.make_request( "PUT", self.url_other_user, access_token=self.admin_user_tok, - content=json.dumps({"deactivated": True}).encode(encoding="utf_8"), + content={"deactivated": False, "password": "foo"}, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertFalse(channel.json_body["deactivated"]) + self.assertIsNotNone(channel.json_body["password_hash"]) self._is_erased("@user:test", False) - d = self.store.mark_user_erased("@user:test") - self.assertIsNone(self.get_success(d)) - self._is_erased("@user:test", True) - # Attempt to reactivate the user (without a password). + @override_config({"password_config": {"localdb_enabled": False}}) + def test_reactivate_user_localdb_disabled(self): + """ + Test reactivating another user when using SSO. + """ + + # Deactivate the user. + self._deactivate_user("@user:test") + + # Reactivate the user with a password channel = self.make_request( "PUT", self.url_other_user, access_token=self.admin_user_tok, - content=json.dumps({"deactivated": False}).encode(encoding="utf_8"), + content={"deactivated": False, "password": "foo"}, ) - self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - # Reactivate the user. + # Reactivate the user without a password. channel = self.make_request( "PUT", self.url_other_user, access_token=self.admin_user_tok, - content=json.dumps({"deactivated": False, "password": "foo"}).encode( - encoding="utf_8" - ), + content={"deactivated": False}, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual("@user:test", channel.json_body["name"]) + self.assertFalse(channel.json_body["deactivated"]) + self.assertIsNone(channel.json_body["password_hash"]) + self._is_erased("@user:test", False) - # Get user + @override_config({"password_config": {"enabled": False}}) + def test_reactivate_user_password_disabled(self): + """ + Test reactivating another user when using SSO. + """ + + # Deactivate the user. + self._deactivate_user("@user:test") + + # Reactivate the user with a password channel = self.make_request( - "GET", + "PUT", self.url_other_user, access_token=self.admin_user_tok, + content={"deactivated": False, "password": "foo"}, ) + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + # Reactivate the user without a password. + channel = self.make_request( + "PUT", + self.url_other_user, + access_token=self.admin_user_tok, + content={"deactivated": False}, + ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@user:test", channel.json_body["name"]) - self.assertEqual(False, channel.json_body["deactivated"]) + self.assertFalse(channel.json_body["deactivated"]) + self.assertIsNone(channel.json_body["password_hash"]) self._is_erased("@user:test", False) def test_set_user_as_admin(self): @@ -1570,18 +1621,16 @@ class UserRestTestCase(unittest.HomeserverTestCase): """ # Set a user as an admin - body = json.dumps({"admin": True}) - channel = self.make_request( "PUT", self.url_other_user, access_token=self.admin_user_tok, - content=body.encode(encoding="utf_8"), + content={"admin": True}, ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@user:test", channel.json_body["name"]) - self.assertEqual(True, channel.json_body["admin"]) + self.assertTrue(channel.json_body["admin"]) # Get user channel = self.make_request( @@ -1592,7 +1641,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("@user:test", channel.json_body["name"]) - self.assertEqual(True, channel.json_body["admin"]) + self.assertTrue(channel.json_body["admin"]) def test_accidental_deactivation_prevention(self): """ @@ -1602,13 +1651,11 @@ class UserRestTestCase(unittest.HomeserverTestCase): url = "/_synapse/admin/v2/users/@bob:test" # Create user - body = json.dumps({"password": "abc123"}) - channel = self.make_request( "PUT", url, access_token=self.admin_user_tok, - content=body.encode(encoding="utf_8"), + content={"password": "abc123"}, ) self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"]) @@ -1628,13 +1675,11 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual(0, channel.json_body["deactivated"]) # Change password (and use a str for deactivate instead of a bool) - body = json.dumps({"password": "abc123", "deactivated": "false"}) # oops! - channel = self.make_request( "PUT", url, access_token=self.admin_user_tok, - content=body.encode(encoding="utf_8"), + content={"password": "abc123", "deactivated": "false"}, ) self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) @@ -1653,7 +1698,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): # Ensure they're still alive self.assertEqual(0, channel.json_body["deactivated"]) - def _is_erased(self, user_id, expect): + def _is_erased(self, user_id: str, expect: bool) -> None: """Assert that the user is erased or not""" d = self.store.is_user_erased(user_id) if expect: @@ -1661,6 +1706,24 @@ class UserRestTestCase(unittest.HomeserverTestCase): else: self.assertFalse(self.get_success(d)) + def _deactivate_user(self, user_id: str) -> None: + """Deactivate user and set as erased""" + + # Deactivate the user. + channel = self.make_request( + "PUT", + "/_synapse/admin/v2/users/%s" % urllib.parse.quote(user_id), + access_token=self.admin_user_tok, + content={"deactivated": True}, + ) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + self.assertTrue(channel.json_body["deactivated"]) + self.assertIsNone(channel.json_body["password_hash"]) + self._is_erased(user_id, False) + d = self.store.mark_user_erased(user_id) + self.assertIsNone(self.get_success(d)) + self._is_erased(user_id, True) + class UserMembershipRestTestCase(unittest.HomeserverTestCase): -- cgit 1.4.1