summary refs log tree commit diff
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--changelog.d/9636.bugfix1
-rw-r--r--synapse/handlers/set_password.py2
-rw-r--r--synapse/rest/admin/users.py2
-rw-r--r--synapse/storage/databases/main/registration.py1
-rw-r--r--tests/rest/admin/test_user.py173
5 files changed, 122 insertions, 57 deletions
diff --git a/changelog.d/9636.bugfix b/changelog.d/9636.bugfix
new file mode 100644
index 0000000000..fa772ed6fc
--- /dev/null
+++ b/changelog.d/9636.bugfix
@@ -0,0 +1 @@
+Checks if passwords are allowed before setting it for the user.
\ No newline at end of file
diff --git a/synapse/handlers/set_password.py b/synapse/handlers/set_password.py
index 84af2dde7e..04e7c64c94 100644
--- a/synapse/handlers/set_password.py
+++ b/synapse/handlers/set_password.py
@@ -41,7 +41,7 @@ class SetPasswordHandler(BaseHandler):
         logout_devices: bool,
         requester: Optional[Requester] = None,
     ) -> None:
-        if not self.hs.config.password_localdb_enabled:
+        if not self._auth_handler.can_change_password():
             raise SynapseError(403, "Password change disabled", errcode=Codes.FORBIDDEN)
 
         try:
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index 2c89b62e25..aaa56a7024 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -271,7 +271,7 @@ class UserRestServletV2(RestServlet):
                 elif not deactivate and user["deactivated"]:
                     if (
                         "password" not in body
-                        and self.hs.config.password_localdb_enabled
+                        and self.auth_handler.can_change_password()
                     ):
                         raise SynapseError(
                             400, "Must provide a password to re-activate an account."
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index eba66ff352..90a8f664ef 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -1210,6 +1210,7 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
         self._invalidate_cache_and_stream(
             txn, self.get_user_deactivated_status, (user_id,)
         )
+        self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,))
         txn.call_after(self.is_guest.invalidate, (user_id,))
 
     @cached()
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index e58d5cf0db..cf61f284cb 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -1003,12 +1003,23 @@ class UserRestTestCase(unittest.HomeserverTestCase):
 
     def prepare(self, reactor, clock, hs):
         self.store = hs.get_datastore()
+        self.auth_handler = hs.get_auth_handler()
 
+        # create users and get access tokens
+        # regardless of whether password login or SSO is allowed
         self.admin_user = self.register_user("admin", "pass", admin=True)
-        self.admin_user_tok = self.login("admin", "pass")
+        self.admin_user_tok = self.get_success(
+            self.auth_handler.get_access_token_for_user_id(
+                self.admin_user, device_id=None, valid_until_ms=None
+            )
+        )
 
         self.other_user = self.register_user("user", "pass", displayname="User")
-        self.other_user_token = self.login("user", "pass")
+        self.other_user_token = self.get_success(
+            self.auth_handler.get_access_token_for_user_id(
+                self.other_user, device_id=None, valid_until_ms=None
+            )
+        )
         self.url_other_user = "/_synapse/admin/v2/users/%s" % urllib.parse.quote(
             self.other_user
         )
@@ -1081,7 +1092,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         self.assertEqual("Bob's name", channel.json_body["displayname"])
         self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
         self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
-        self.assertEqual(True, channel.json_body["admin"])
+        self.assertTrue(channel.json_body["admin"])
         self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"])
 
         # Get user
@@ -1096,9 +1107,9 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         self.assertEqual("Bob's name", channel.json_body["displayname"])
         self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
         self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
-        self.assertEqual(True, channel.json_body["admin"])
-        self.assertEqual(False, channel.json_body["is_guest"])
-        self.assertEqual(False, channel.json_body["deactivated"])
+        self.assertTrue(channel.json_body["admin"])
+        self.assertFalse(channel.json_body["is_guest"])
+        self.assertFalse(channel.json_body["deactivated"])
         self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"])
 
     def test_create_user(self):
@@ -1130,7 +1141,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         self.assertEqual("Bob's name", channel.json_body["displayname"])
         self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
         self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
-        self.assertEqual(False, channel.json_body["admin"])
+        self.assertFalse(channel.json_body["admin"])
         self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"])
 
         # Get user
@@ -1145,10 +1156,10 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         self.assertEqual("Bob's name", channel.json_body["displayname"])
         self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
         self.assertEqual("bob@bob.bob", channel.json_body["threepids"][0]["address"])
-        self.assertEqual(False, channel.json_body["admin"])
-        self.assertEqual(False, channel.json_body["is_guest"])
-        self.assertEqual(False, channel.json_body["deactivated"])
-        self.assertEqual(False, channel.json_body["shadow_banned"])
+        self.assertFalse(channel.json_body["admin"])
+        self.assertFalse(channel.json_body["is_guest"])
+        self.assertFalse(channel.json_body["deactivated"])
+        self.assertFalse(channel.json_body["shadow_banned"])
         self.assertEqual("mxc://fibble/wibble", channel.json_body["avatar_url"])
 
     @override_config(
@@ -1197,7 +1208,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
 
         self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual("@bob:test", channel.json_body["name"])
-        self.assertEqual(False, channel.json_body["admin"])
+        self.assertFalse(channel.json_body["admin"])
 
     @override_config(
         {"limit_usage_by_mau": True, "max_mau_value": 2, "mau_trial_days": 0}
@@ -1237,7 +1248,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         # Admin user is not blocked by mau anymore
         self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual("@bob:test", channel.json_body["name"])
-        self.assertEqual(False, channel.json_body["admin"])
+        self.assertFalse(channel.json_body["admin"])
 
     @override_config(
         {
@@ -1429,24 +1440,23 @@ class UserRestTestCase(unittest.HomeserverTestCase):
 
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual("@user:test", channel.json_body["name"])
-        self.assertEqual(False, channel.json_body["deactivated"])
+        self.assertFalse(channel.json_body["deactivated"])
         self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"])
         self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"])
         self.assertEqual("User", channel.json_body["displayname"])
 
         # Deactivate user
-        body = json.dumps({"deactivated": True})
-
         channel = self.make_request(
             "PUT",
             self.url_other_user,
             access_token=self.admin_user_tok,
-            content=body.encode(encoding="utf_8"),
+            content={"deactivated": True},
         )
 
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual("@user:test", channel.json_body["name"])
-        self.assertEqual(True, channel.json_body["deactivated"])
+        self.assertTrue(channel.json_body["deactivated"])
+        self.assertIsNone(channel.json_body["password_hash"])
         self.assertEqual(0, len(channel.json_body["threepids"]))
         self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"])
         self.assertEqual("User", channel.json_body["displayname"])
@@ -1461,7 +1471,8 @@ class UserRestTestCase(unittest.HomeserverTestCase):
 
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual("@user:test", channel.json_body["name"])
-        self.assertEqual(True, channel.json_body["deactivated"])
+        self.assertTrue(channel.json_body["deactivated"])
+        self.assertIsNone(channel.json_body["password_hash"])
         self.assertEqual(0, len(channel.json_body["threepids"]))
         self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"])
         self.assertEqual("User", channel.json_body["displayname"])
@@ -1478,41 +1489,37 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         self.assertTrue(profile["display_name"] == "User")
 
         # Deactivate user
-        body = json.dumps({"deactivated": True})
-
         channel = self.make_request(
             "PUT",
             self.url_other_user,
             access_token=self.admin_user_tok,
-            content=body.encode(encoding="utf_8"),
+            content={"deactivated": True},
         )
 
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual("@user:test", channel.json_body["name"])
-        self.assertEqual(True, channel.json_body["deactivated"])
+        self.assertTrue(channel.json_body["deactivated"])
 
         # is not in user directory
         profile = self.get_success(self.store.get_user_in_directory(self.other_user))
-        self.assertTrue(profile is None)
+        self.assertIsNone(profile)
 
         # Set new displayname user
-        body = json.dumps({"displayname": "Foobar"})
-
         channel = self.make_request(
             "PUT",
             self.url_other_user,
             access_token=self.admin_user_tok,
-            content=body.encode(encoding="utf_8"),
+            content={"displayname": "Foobar"},
         )
 
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual("@user:test", channel.json_body["name"])
-        self.assertEqual(True, channel.json_body["deactivated"])
+        self.assertTrue(channel.json_body["deactivated"])
         self.assertEqual("Foobar", channel.json_body["displayname"])
 
         # is not in user directory
         profile = self.get_success(self.store.get_user_in_directory(self.other_user))
-        self.assertTrue(profile is None)
+        self.assertIsNone(profile)
 
     def test_reactivate_user(self):
         """
@@ -1520,48 +1527,92 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         """
 
         # Deactivate the user.
+        self._deactivate_user("@user:test")
+
+        # Attempt to reactivate the user (without a password).
+        channel = self.make_request(
+            "PUT",
+            self.url_other_user,
+            access_token=self.admin_user_tok,
+            content={"deactivated": False},
+        )
+        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+
+        # Reactivate the user.
         channel = self.make_request(
             "PUT",
             self.url_other_user,
             access_token=self.admin_user_tok,
-            content=json.dumps({"deactivated": True}).encode(encoding="utf_8"),
+            content={"deactivated": False, "password": "foo"},
         )
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual("@user:test", channel.json_body["name"])
+        self.assertFalse(channel.json_body["deactivated"])
+        self.assertIsNotNone(channel.json_body["password_hash"])
         self._is_erased("@user:test", False)
-        d = self.store.mark_user_erased("@user:test")
-        self.assertIsNone(self.get_success(d))
-        self._is_erased("@user:test", True)
 
-        # Attempt to reactivate the user (without a password).
+    @override_config({"password_config": {"localdb_enabled": False}})
+    def test_reactivate_user_localdb_disabled(self):
+        """
+        Test reactivating another user when using SSO.
+        """
+
+        # Deactivate the user.
+        self._deactivate_user("@user:test")
+
+        # Reactivate the user with a password
         channel = self.make_request(
             "PUT",
             self.url_other_user,
             access_token=self.admin_user_tok,
-            content=json.dumps({"deactivated": False}).encode(encoding="utf_8"),
+            content={"deactivated": False, "password": "foo"},
         )
-        self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
 
-        # Reactivate the user.
+        # Reactivate the user without a password.
         channel = self.make_request(
             "PUT",
             self.url_other_user,
             access_token=self.admin_user_tok,
-            content=json.dumps({"deactivated": False, "password": "foo"}).encode(
-                encoding="utf_8"
-            ),
+            content={"deactivated": False},
         )
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual("@user:test", channel.json_body["name"])
+        self.assertFalse(channel.json_body["deactivated"])
+        self.assertIsNone(channel.json_body["password_hash"])
+        self._is_erased("@user:test", False)
 
-        # Get user
+    @override_config({"password_config": {"enabled": False}})
+    def test_reactivate_user_password_disabled(self):
+        """
+        Test reactivating another user when using SSO.
+        """
+
+        # Deactivate the user.
+        self._deactivate_user("@user:test")
+
+        # Reactivate the user with a password
         channel = self.make_request(
-            "GET",
+            "PUT",
             self.url_other_user,
             access_token=self.admin_user_tok,
+            content={"deactivated": False, "password": "foo"},
         )
+        self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
 
+        # Reactivate the user without a password.
+        channel = self.make_request(
+            "PUT",
+            self.url_other_user,
+            access_token=self.admin_user_tok,
+            content={"deactivated": False},
+        )
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual("@user:test", channel.json_body["name"])
-        self.assertEqual(False, channel.json_body["deactivated"])
+        self.assertFalse(channel.json_body["deactivated"])
+        self.assertIsNone(channel.json_body["password_hash"])
         self._is_erased("@user:test", False)
 
     def test_set_user_as_admin(self):
@@ -1570,18 +1621,16 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         """
 
         # Set a user as an admin
-        body = json.dumps({"admin": True})
-
         channel = self.make_request(
             "PUT",
             self.url_other_user,
             access_token=self.admin_user_tok,
-            content=body.encode(encoding="utf_8"),
+            content={"admin": True},
         )
 
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual("@user:test", channel.json_body["name"])
-        self.assertEqual(True, channel.json_body["admin"])
+        self.assertTrue(channel.json_body["admin"])
 
         # Get user
         channel = self.make_request(
@@ -1592,7 +1641,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
 
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual("@user:test", channel.json_body["name"])
-        self.assertEqual(True, channel.json_body["admin"])
+        self.assertTrue(channel.json_body["admin"])
 
     def test_accidental_deactivation_prevention(self):
         """
@@ -1602,13 +1651,11 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         url = "/_synapse/admin/v2/users/@bob:test"
 
         # Create user
-        body = json.dumps({"password": "abc123"})
-
         channel = self.make_request(
             "PUT",
             url,
             access_token=self.admin_user_tok,
-            content=body.encode(encoding="utf_8"),
+            content={"password": "abc123"},
         )
 
         self.assertEqual(201, int(channel.result["code"]), msg=channel.result["body"])
@@ -1628,13 +1675,11 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         self.assertEqual(0, channel.json_body["deactivated"])
 
         # Change password (and use a str for deactivate instead of a bool)
-        body = json.dumps({"password": "abc123", "deactivated": "false"})  # oops!
-
         channel = self.make_request(
             "PUT",
             url,
             access_token=self.admin_user_tok,
-            content=body.encode(encoding="utf_8"),
+            content={"password": "abc123", "deactivated": "false"},
         )
 
         self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
@@ -1653,7 +1698,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         # Ensure they're still alive
         self.assertEqual(0, channel.json_body["deactivated"])
 
-    def _is_erased(self, user_id, expect):
+    def _is_erased(self, user_id: str, expect: bool) -> None:
         """Assert that the user is erased or not"""
         d = self.store.is_user_erased(user_id)
         if expect:
@@ -1661,6 +1706,24 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         else:
             self.assertFalse(self.get_success(d))
 
+    def _deactivate_user(self, user_id: str) -> None:
+        """Deactivate user and set as erased"""
+
+        # Deactivate the user.
+        channel = self.make_request(
+            "PUT",
+            "/_synapse/admin/v2/users/%s" % urllib.parse.quote(user_id),
+            access_token=self.admin_user_tok,
+            content={"deactivated": True},
+        )
+        self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+        self.assertTrue(channel.json_body["deactivated"])
+        self.assertIsNone(channel.json_body["password_hash"])
+        self._is_erased(user_id, False)
+        d = self.store.mark_user_erased(user_id)
+        self.assertIsNone(self.get_success(d))
+        self._is_erased(user_id, True)
+
 
 class UserMembershipRestTestCase(unittest.HomeserverTestCase):