summary refs log tree commit diff
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--changelog.d/9588.bugfix1
-rw-r--r--synapse/handlers/auth.py13
-rw-r--r--synapse/rest/client/v2_alpha/capabilities.py23
-rw-r--r--tests/rest/client/v2_alpha/test_capabilities.py36
4 files changed, 58 insertions, 15 deletions
diff --git a/changelog.d/9588.bugfix b/changelog.d/9588.bugfix
new file mode 100644
index 0000000000..b8d6140565
--- /dev/null
+++ b/changelog.d/9588.bugfix
@@ -0,0 +1 @@
+Fix the `/capabilities` endpoint to return `m.change_password` as disabled if the local password database is not used for authentication. Contributed by @dklimpel.
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index fb5f8118f0..badac8c26c 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -886,6 +886,19 @@ class AuthHandler(BaseHandler):
             )
         return result
 
+    def can_change_password(self) -> bool:
+        """Get whether users on this server are allowed to change or set a password.
+
+        Both `config.password_enabled` and `config.password_localdb_enabled` must be true.
+
+        Note that any account (even SSO accounts) are allowed to add passwords if the above
+        is true.
+
+        Returns:
+            Whether users on this server are allowed to change or set a password
+        """
+        return self._password_enabled and self._password_localdb_enabled
+
     def get_supported_login_types(self) -> Iterable[str]:
         """Get a the login types supported for the /login API
 
diff --git a/synapse/rest/client/v2_alpha/capabilities.py b/synapse/rest/client/v2_alpha/capabilities.py
index 76879ac559..44ccf10ed4 100644
--- a/synapse/rest/client/v2_alpha/capabilities.py
+++ b/synapse/rest/client/v2_alpha/capabilities.py
@@ -13,12 +13,18 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+from typing import TYPE_CHECKING, Tuple
 
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
 from synapse.http.servlet import RestServlet
+from synapse.http.site import SynapseRequest
+from synapse.types import JsonDict
 
 from ._base import client_patterns
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
@@ -27,21 +33,16 @@ class CapabilitiesRestServlet(RestServlet):
 
     PATTERNS = client_patterns("/capabilities$")
 
-    def __init__(self, hs):
-        """
-        Args:
-            hs (synapse.server.HomeServer): server
-        """
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.hs = hs
         self.config = hs.config
         self.auth = hs.get_auth()
-        self.store = hs.get_datastore()
+        self.auth_handler = hs.get_auth_handler()
 
-    async def on_GET(self, request):
-        requester = await self.auth.get_user_by_req(request, allow_guest=True)
-        user = await self.store.get_user_by_id(requester.user.to_string())
-        change_password = bool(user["password_hash"])
+    async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
+        await self.auth.get_user_by_req(request, allow_guest=True)
+        change_password = self.auth_handler.can_change_password()
 
         response = {
             "capabilities": {
@@ -58,5 +59,5 @@ class CapabilitiesRestServlet(RestServlet):
         return 200, response
 
 
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server):
     CapabilitiesRestServlet(hs).register(http_server)
diff --git a/tests/rest/client/v2_alpha/test_capabilities.py b/tests/rest/client/v2_alpha/test_capabilities.py
index e808339fb3..287a1a485c 100644
--- a/tests/rest/client/v2_alpha/test_capabilities.py
+++ b/tests/rest/client/v2_alpha/test_capabilities.py
@@ -18,6 +18,7 @@ from synapse.rest.client.v1 import login
 from synapse.rest.client.v2_alpha import capabilities
 
 from tests import unittest
+from tests.unittest import override_config
 
 
 class CapabilitiesTestCase(unittest.HomeserverTestCase):
@@ -33,6 +34,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
         hs = self.setup_test_homeserver()
         self.store = hs.get_datastore()
         self.config = hs.config
+        self.auth_handler = hs.get_auth_handler()
         return hs
 
     def test_check_auth_required(self):
@@ -56,7 +58,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
             capabilities["m.room_versions"]["default"],
         )
 
-    def test_get_change_password_capabilities(self):
+    def test_get_change_password_capabilities_password_login(self):
         localpart = "user"
         password = "pass"
         user = self.register_user(localpart, password)
@@ -66,10 +68,36 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase):
         capabilities = channel.json_body["capabilities"]
 
         self.assertEqual(channel.code, 200)
-
-        # Test case where password is handled outside of Synapse
         self.assertTrue(capabilities["m.change_password"]["enabled"])
-        self.get_success(self.store.user_set_password_hash(user, None))
+
+    @override_config({"password_config": {"localdb_enabled": False}})
+    def test_get_change_password_capabilities_localdb_disabled(self):
+        localpart = "user"
+        password = "pass"
+        user = self.register_user(localpart, password)
+        access_token = self.get_success(
+            self.auth_handler.get_access_token_for_user_id(
+                user, device_id=None, valid_until_ms=None
+            )
+        )
+
+        channel = self.make_request("GET", self.url, access_token=access_token)
+        capabilities = channel.json_body["capabilities"]
+
+        self.assertEqual(channel.code, 200)
+        self.assertFalse(capabilities["m.change_password"]["enabled"])
+
+    @override_config({"password_config": {"enabled": False}})
+    def test_get_change_password_capabilities_password_disabled(self):
+        localpart = "user"
+        password = "pass"
+        user = self.register_user(localpart, password)
+        access_token = self.get_success(
+            self.auth_handler.get_access_token_for_user_id(
+                user, device_id=None, valid_until_ms=None
+            )
+        )
+
         channel = self.make_request("GET", self.url, access_token=access_token)
         capabilities = channel.json_body["capabilities"]