diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py
new file mode 100644
index 0000000000..edfab8a13a
--- /dev/null
+++ b/tests/handlers/test_password_providers.py
@@ -0,0 +1,486 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for the password_auth_provider interface"""
+
+from typing import Any, Type, Union
+
+from mock import Mock
+
+from twisted.internet import defer
+
+import synapse
+from synapse.rest.client.v1 import login
+from synapse.rest.client.v2_alpha import devices
+from synapse.types import JsonDict
+
+from tests import unittest
+from tests.server import FakeChannel
+from tests.unittest import override_config
+
+# (possibly experimental) login flows we expect to appear in the list after the normal
+# ones
+ADDITIONAL_LOGIN_FLOWS = [{"type": "uk.half-shot.msc2778.login.application_service"}]
+
+# a mock instance which the dummy auth providers delegate to, so we can see what's going
+# on
+mock_password_provider = Mock()
+
+
+class PasswordOnlyAuthProvider:
+ """A password_provider which only implements `check_password`."""
+
+ @staticmethod
+ def parse_config(self):
+ pass
+
+ def __init__(self, config, account_handler):
+ pass
+
+ def check_password(self, *args):
+ return mock_password_provider.check_password(*args)
+
+
+class CustomAuthProvider:
+ """A password_provider which implements a custom login type."""
+
+ @staticmethod
+ def parse_config(self):
+ pass
+
+ def __init__(self, config, account_handler):
+ pass
+
+ def get_supported_login_types(self):
+ return {"test.login_type": ["test_field"]}
+
+ def check_auth(self, *args):
+ return mock_password_provider.check_auth(*args)
+
+
+def providers_config(*providers: Type[Any]) -> dict:
+ """Returns a config dict that will enable the given password auth providers"""
+ return {
+ "password_providers": [
+ {"module": "%s.%s" % (__name__, provider.__qualname__), "config": {}}
+ for provider in providers
+ ]
+ }
+
+
+class PasswordAuthProviderTests(unittest.HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ devices.register_servlets,
+ ]
+
+ def setUp(self):
+ # we use a global mock device, so make sure we are starting with a clean slate
+ mock_password_provider.reset_mock()
+ super().setUp()
+
+ @override_config(providers_config(PasswordOnlyAuthProvider))
+ def test_password_only_auth_provider_login(self):
+ # login flows should only have m.login.password
+ flows = self._get_login_flows()
+ self.assertEqual(flows, [{"type": "m.login.password"}] + ADDITIONAL_LOGIN_FLOWS)
+
+ # check_password must return an awaitable
+ mock_password_provider.check_password.return_value = defer.succeed(True)
+ channel = self._send_password_login("u", "p")
+ self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual("@u:test", channel.json_body["user_id"])
+ mock_password_provider.check_password.assert_called_once_with("@u:test", "p")
+ mock_password_provider.reset_mock()
+
+ # login with mxid should work too
+ channel = self._send_password_login("@u:bz", "p")
+ self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual("@u:bz", channel.json_body["user_id"])
+ mock_password_provider.check_password.assert_called_once_with("@u:bz", "p")
+ mock_password_provider.reset_mock()
+
+ # try a weird username / pass. Honestly it's unclear what we *expect* to happen
+ # in these cases, but at least we can guard against the API changing
+ # unexpectedly
+ channel = self._send_password_login(" USER🙂NAME ", " pASS\U0001F622word ")
+ self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual("@ USER🙂NAME :test", channel.json_body["user_id"])
+ mock_password_provider.check_password.assert_called_once_with(
+ "@ USER🙂NAME :test", " pASS😢word "
+ )
+
+ @override_config(providers_config(PasswordOnlyAuthProvider))
+ def test_password_only_auth_provider_ui_auth(self):
+ """UI Auth should delegate correctly to the password provider"""
+
+ # create the user, otherwise access doesn't work
+ module_api = self.hs.get_module_api()
+ self.get_success(module_api.register_user("u"))
+
+ # log in twice, to get two devices
+ mock_password_provider.check_password.return_value = defer.succeed(True)
+ tok1 = self.login("u", "p")
+ self.login("u", "p", device_id="dev2")
+ mock_password_provider.reset_mock()
+
+ # have the auth provider deny the request to start with
+ mock_password_provider.check_password.return_value = defer.succeed(False)
+
+ # make the initial request which returns a 401
+ session = self._start_delete_device_session(tok1, "dev2")
+ mock_password_provider.check_password.assert_not_called()
+
+ # Make another request providing the UI auth flow.
+ channel = self._authed_delete_device(tok1, "dev2", session, "u", "p")
+ self.assertEqual(channel.code, 401) # XXX why not a 403?
+ self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
+ mock_password_provider.check_password.assert_called_once_with("@u:test", "p")
+ mock_password_provider.reset_mock()
+
+ # Finally, check the request goes through when we allow it
+ mock_password_provider.check_password.return_value = defer.succeed(True)
+ channel = self._authed_delete_device(tok1, "dev2", session, "u", "p")
+ self.assertEqual(channel.code, 200)
+ mock_password_provider.check_password.assert_called_once_with("@u:test", "p")
+
+ @override_config(providers_config(PasswordOnlyAuthProvider))
+ def test_local_user_fallback_login(self):
+ """rejected login should fall back to local db"""
+ self.register_user("localuser", "localpass")
+
+ # check_password must return an awaitable
+ mock_password_provider.check_password.return_value = defer.succeed(False)
+ channel = self._send_password_login("u", "p")
+ self.assertEqual(channel.code, 403, channel.result)
+
+ channel = self._send_password_login("localuser", "localpass")
+ self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual("@localuser:test", channel.json_body["user_id"])
+
+ @override_config(providers_config(PasswordOnlyAuthProvider))
+ def test_local_user_fallback_ui_auth(self):
+ """rejected login should fall back to local db"""
+ self.register_user("localuser", "localpass")
+
+ # have the auth provider deny the request
+ mock_password_provider.check_password.return_value = defer.succeed(False)
+
+ # log in twice, to get two devices
+ tok1 = self.login("localuser", "localpass")
+ self.login("localuser", "localpass", device_id="dev2")
+ mock_password_provider.check_password.reset_mock()
+
+ # first delete should give a 401
+ session = self._start_delete_device_session(tok1, "dev2")
+ mock_password_provider.check_password.assert_not_called()
+
+ # Wrong password
+ channel = self._authed_delete_device(tok1, "dev2", session, "localuser", "xxx")
+ self.assertEqual(channel.code, 401) # XXX why not a 403?
+ self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
+ mock_password_provider.check_password.assert_called_once_with(
+ "@localuser:test", "xxx"
+ )
+ mock_password_provider.reset_mock()
+
+ # Right password
+ channel = self._authed_delete_device(
+ tok1, "dev2", session, "localuser", "localpass"
+ )
+ self.assertEqual(channel.code, 200)
+ mock_password_provider.check_password.assert_called_once_with(
+ "@localuser:test", "localpass"
+ )
+
+ @override_config(
+ {
+ **providers_config(PasswordOnlyAuthProvider),
+ "password_config": {"localdb_enabled": False},
+ }
+ )
+ def test_no_local_user_fallback_login(self):
+ """localdb_enabled can block login with the local password
+ """
+ self.register_user("localuser", "localpass")
+
+ # check_password must return an awaitable
+ mock_password_provider.check_password.return_value = defer.succeed(False)
+ channel = self._send_password_login("localuser", "localpass")
+ self.assertEqual(channel.code, 403)
+ self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
+ mock_password_provider.check_password.assert_called_once_with(
+ "@localuser:test", "localpass"
+ )
+
+ @override_config(
+ {
+ **providers_config(PasswordOnlyAuthProvider),
+ "password_config": {"localdb_enabled": False},
+ }
+ )
+ def test_no_local_user_fallback_ui_auth(self):
+ """localdb_enabled can block ui auth with the local password
+ """
+ self.register_user("localuser", "localpass")
+
+ # allow login via the auth provider
+ mock_password_provider.check_password.return_value = defer.succeed(True)
+
+ # log in twice, to get two devices
+ tok1 = self.login("localuser", "p")
+ self.login("localuser", "p", device_id="dev2")
+ mock_password_provider.check_password.reset_mock()
+
+ # first delete should give a 401
+ session = self._start_delete_device_session(tok1, "dev2")
+ mock_password_provider.check_password.assert_not_called()
+
+ # now try deleting with the local password
+ mock_password_provider.check_password.return_value = defer.succeed(False)
+ channel = self._authed_delete_device(
+ tok1, "dev2", session, "localuser", "localpass"
+ )
+ self.assertEqual(channel.code, 401) # XXX why not a 403?
+ self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
+ mock_password_provider.check_password.assert_called_once_with(
+ "@localuser:test", "localpass"
+ )
+
+ @override_config(
+ {
+ **providers_config(PasswordOnlyAuthProvider),
+ "password_config": {"enabled": False},
+ }
+ )
+ def test_password_auth_disabled(self):
+ """password auth doesn't work if it's disabled across the board"""
+ # login flows should be empty
+ flows = self._get_login_flows()
+ self.assertEqual(flows, ADDITIONAL_LOGIN_FLOWS)
+
+ # login shouldn't work and should be rejected with a 400 ("unknown login type")
+ channel = self._send_password_login("u", "p")
+ self.assertEqual(channel.code, 400, channel.result)
+ mock_password_provider.check_password.assert_not_called()
+
+ @override_config(providers_config(CustomAuthProvider))
+ def test_custom_auth_provider_login(self):
+ # login flows should have the custom flow and m.login.password, since we
+ # haven't disabled local password lookup.
+ # (password must come first, because reasons)
+ flows = self._get_login_flows()
+ self.assertEqual(
+ flows,
+ [{"type": "m.login.password"}, {"type": "test.login_type"}]
+ + ADDITIONAL_LOGIN_FLOWS,
+ )
+
+ # login with missing param should be rejected
+ channel = self._send_login("test.login_type", "u")
+ self.assertEqual(channel.code, 400, channel.result)
+ mock_password_provider.check_auth.assert_not_called()
+
+ mock_password_provider.check_auth.return_value = defer.succeed("@user:bz")
+ channel = self._send_login("test.login_type", "u", test_field="y")
+ self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual("@user:bz", channel.json_body["user_id"])
+ mock_password_provider.check_auth.assert_called_once_with(
+ "u", "test.login_type", {"test_field": "y"}
+ )
+ mock_password_provider.reset_mock()
+
+ # try a weird username. Again, it's unclear what we *expect* to happen
+ # in these cases, but at least we can guard against the API changing
+ # unexpectedly
+ mock_password_provider.check_auth.return_value = defer.succeed(
+ "@ MALFORMED! :bz"
+ )
+ channel = self._send_login("test.login_type", " USER🙂NAME ", test_field=" abc ")
+ self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual("@ MALFORMED! :bz", channel.json_body["user_id"])
+ mock_password_provider.check_auth.assert_called_once_with(
+ " USER🙂NAME ", "test.login_type", {"test_field": " abc "}
+ )
+
+ @override_config(providers_config(CustomAuthProvider))
+ def test_custom_auth_provider_ui_auth(self):
+ # register the user and log in twice, to get two devices
+ self.register_user("localuser", "localpass")
+ tok1 = self.login("localuser", "localpass")
+ self.login("localuser", "localpass", device_id="dev2")
+
+ # make the initial request which returns a 401
+ channel = self._delete_device(tok1, "dev2")
+ self.assertEqual(channel.code, 401)
+ # Ensure that flows are what is expected.
+ self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"])
+ self.assertIn({"stages": ["test.login_type"]}, channel.json_body["flows"])
+ session = channel.json_body["session"]
+
+ # missing param
+ body = {
+ "auth": {
+ "type": "test.login_type",
+ "identifier": {"type": "m.id.user", "user": "localuser"},
+ # FIXME "identifier" is ignored
+ # https://github.com/matrix-org/synapse/issues/5665
+ "user": "localuser",
+ "session": session,
+ },
+ }
+
+ channel = self._delete_device(tok1, "dev2", body)
+ self.assertEqual(channel.code, 400)
+ # there's a perfectly good M_MISSING_PARAM errcode, but heaven forfend we should
+ # use it...
+ self.assertIn("Missing parameters", channel.json_body["error"])
+ mock_password_provider.check_auth.assert_not_called()
+ mock_password_provider.reset_mock()
+
+ # right params, but authing as the wrong user
+ mock_password_provider.check_auth.return_value = defer.succeed("@user:bz")
+ body["auth"]["test_field"] = "foo"
+ channel = self._delete_device(tok1, "dev2", body)
+ self.assertEqual(channel.code, 403)
+ self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
+ mock_password_provider.check_auth.assert_called_once_with(
+ "localuser", "test.login_type", {"test_field": "foo"}
+ )
+ mock_password_provider.reset_mock()
+
+ # and finally, succeed
+ mock_password_provider.check_auth.return_value = defer.succeed(
+ "@localuser:test"
+ )
+ channel = self._delete_device(tok1, "dev2", body)
+ self.assertEqual(channel.code, 200)
+ mock_password_provider.check_auth.assert_called_once_with(
+ "localuser", "test.login_type", {"test_field": "foo"}
+ )
+
+ @override_config(providers_config(CustomAuthProvider))
+ def test_custom_auth_provider_callback(self):
+ callback = Mock(return_value=defer.succeed(None))
+
+ mock_password_provider.check_auth.return_value = defer.succeed(
+ ("@user:bz", callback)
+ )
+ channel = self._send_login("test.login_type", "u", test_field="y")
+ self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual("@user:bz", channel.json_body["user_id"])
+ mock_password_provider.check_auth.assert_called_once_with(
+ "u", "test.login_type", {"test_field": "y"}
+ )
+
+ # check the args to the callback
+ callback.assert_called_once()
+ call_args, call_kwargs = callback.call_args
+ # should be one positional arg
+ self.assertEqual(len(call_args), 1)
+ self.assertEqual(call_args[0]["user_id"], "@user:bz")
+ for p in ["user_id", "access_token", "device_id", "home_server"]:
+ self.assertIn(p, call_args[0])
+
+ @override_config(
+ {**providers_config(CustomAuthProvider), "password_config": {"enabled": False}}
+ )
+ def test_custom_auth_password_disabled(self):
+ """Test login with a custom auth provider where password login is disabled"""
+ self.register_user("localuser", "localpass")
+
+ flows = self._get_login_flows()
+ self.assertEqual(flows, [{"type": "test.login_type"}] + ADDITIONAL_LOGIN_FLOWS)
+
+ # login shouldn't work and should be rejected with a 400 ("unknown login type")
+ channel = self._send_password_login("localuser", "localpass")
+ self.assertEqual(channel.code, 400, channel.result)
+ mock_password_provider.check_auth.assert_not_called()
+
+ @override_config(
+ {
+ **providers_config(CustomAuthProvider),
+ "password_config": {"localdb_enabled": False},
+ }
+ )
+ def test_custom_auth_no_local_user_fallback(self):
+ """Test login with a custom auth provider where the local db is disabled"""
+ self.register_user("localuser", "localpass")
+
+ flows = self._get_login_flows()
+ self.assertEqual(flows, [{"type": "test.login_type"}] + ADDITIONAL_LOGIN_FLOWS)
+
+ # password login shouldn't work and should be rejected with a 400
+ # ("unknown login type")
+ channel = self._send_password_login("localuser", "localpass")
+ self.assertEqual(channel.code, 400, channel.result)
+
+ test_custom_auth_no_local_user_fallback.skip = "currently broken"
+
+ def _get_login_flows(self) -> JsonDict:
+ _, channel = self.make_request("GET", "/_matrix/client/r0/login")
+ self.assertEqual(channel.code, 200, channel.result)
+ return channel.json_body["flows"]
+
+ def _send_password_login(self, user: str, password: str) -> FakeChannel:
+ return self._send_login(type="m.login.password", user=user, password=password)
+
+ def _send_login(self, type, user, **params) -> FakeChannel:
+ params.update({"user": user, "type": type})
+ _, channel = self.make_request("POST", "/_matrix/client/r0/login", params)
+ return channel
+
+ def _start_delete_device_session(self, access_token, device_id) -> str:
+ """Make an initial delete device request, and return the UI Auth session ID"""
+ channel = self._delete_device(access_token, device_id)
+ self.assertEqual(channel.code, 401)
+ # Ensure that flows are what is expected.
+ self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"])
+ return channel.json_body["session"]
+
+ def _authed_delete_device(
+ self,
+ access_token: str,
+ device_id: str,
+ session: str,
+ user_id: str,
+ password: str,
+ ) -> FakeChannel:
+ """Make a delete device request, authenticating with the given uid/password"""
+ return self._delete_device(
+ access_token,
+ device_id,
+ {
+ "auth": {
+ "type": "m.login.password",
+ "identifier": {"type": "m.id.user", "user": user_id},
+ # FIXME "identifier" is ignored
+ # https://github.com/matrix-org/synapse/issues/5665
+ "user": user_id,
+ "password": password,
+ "session": session,
+ },
+ },
+ )
+
+ def _delete_device(
+ self, access_token: str, device: str, body: Union[JsonDict, bytes] = b"",
+ ) -> FakeChannel:
+ """Delete an individual device."""
+ _, channel = self.make_request(
+ "DELETE", "devices/" + device, body, access_token=access_token
+ )
+ return channel
|