From ddc43436838e19a7dd16860389bd76c74578dae7 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 1 Dec 2020 11:10:42 +0000 Subject: Add some tests for `password_auth_providers` (#8819) These things seemed to be completely untested, so I added a load of tests for them. --- tests/handlers/test_password_providers.py | 486 ++++++++++++++++++++++++++++++ 1 file changed, 486 insertions(+) create mode 100644 tests/handlers/test_password_providers.py (limited to 'tests') diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py new file mode 100644 index 0000000000..edfab8a13a --- /dev/null +++ b/tests/handlers/test_password_providers.py @@ -0,0 +1,486 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for the password_auth_provider interface""" + +from typing import Any, Type, Union + +from mock import Mock + +from twisted.internet import defer + +import synapse +from synapse.rest.client.v1 import login +from synapse.rest.client.v2_alpha import devices +from synapse.types import JsonDict + +from tests import unittest +from tests.server import FakeChannel +from tests.unittest import override_config + +# (possibly experimental) login flows we expect to appear in the list after the normal +# ones +ADDITIONAL_LOGIN_FLOWS = [{"type": "uk.half-shot.msc2778.login.application_service"}] + +# a mock instance which the dummy auth providers delegate to, so we can see what's going +# on +mock_password_provider = Mock() + + +class PasswordOnlyAuthProvider: + """A password_provider which only implements `check_password`.""" + + @staticmethod + def parse_config(self): + pass + + def __init__(self, config, account_handler): + pass + + def check_password(self, *args): + return mock_password_provider.check_password(*args) + + +class CustomAuthProvider: + """A password_provider which implements a custom login type.""" + + @staticmethod + def parse_config(self): + pass + + def __init__(self, config, account_handler): + pass + + def get_supported_login_types(self): + return {"test.login_type": ["test_field"]} + + def check_auth(self, *args): + return mock_password_provider.check_auth(*args) + + +def providers_config(*providers: Type[Any]) -> dict: + """Returns a config dict that will enable the given password auth providers""" + return { + "password_providers": [ + {"module": "%s.%s" % (__name__, provider.__qualname__), "config": {}} + for provider in providers + ] + } + + +class PasswordAuthProviderTests(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + devices.register_servlets, + ] + + def setUp(self): + # we use a global mock device, so make sure we are starting with a clean slate + mock_password_provider.reset_mock() + super().setUp() + + @override_config(providers_config(PasswordOnlyAuthProvider)) + def test_password_only_auth_provider_login(self): + # login flows should only have m.login.password + flows = self._get_login_flows() + self.assertEqual(flows, [{"type": "m.login.password"}] + ADDITIONAL_LOGIN_FLOWS) + + # check_password must return an awaitable + mock_password_provider.check_password.return_value = defer.succeed(True) + channel = self._send_password_login("u", "p") + self.assertEqual(channel.code, 200, channel.result) + self.assertEqual("@u:test", channel.json_body["user_id"]) + mock_password_provider.check_password.assert_called_once_with("@u:test", "p") + mock_password_provider.reset_mock() + + # login with mxid should work too + channel = self._send_password_login("@u:bz", "p") + self.assertEqual(channel.code, 200, channel.result) + self.assertEqual("@u:bz", channel.json_body["user_id"]) + mock_password_provider.check_password.assert_called_once_with("@u:bz", "p") + mock_password_provider.reset_mock() + + # try a weird username / pass. Honestly it's unclear what we *expect* to happen + # in these cases, but at least we can guard against the API changing + # unexpectedly + channel = self._send_password_login(" USER🙂NAME ", " pASS\U0001F622word ") + self.assertEqual(channel.code, 200, channel.result) + self.assertEqual("@ USER🙂NAME :test", channel.json_body["user_id"]) + mock_password_provider.check_password.assert_called_once_with( + "@ USER🙂NAME :test", " pASS😢word " + ) + + @override_config(providers_config(PasswordOnlyAuthProvider)) + def test_password_only_auth_provider_ui_auth(self): + """UI Auth should delegate correctly to the password provider""" + + # create the user, otherwise access doesn't work + module_api = self.hs.get_module_api() + self.get_success(module_api.register_user("u")) + + # log in twice, to get two devices + mock_password_provider.check_password.return_value = defer.succeed(True) + tok1 = self.login("u", "p") + self.login("u", "p", device_id="dev2") + mock_password_provider.reset_mock() + + # have the auth provider deny the request to start with + mock_password_provider.check_password.return_value = defer.succeed(False) + + # make the initial request which returns a 401 + session = self._start_delete_device_session(tok1, "dev2") + mock_password_provider.check_password.assert_not_called() + + # Make another request providing the UI auth flow. + channel = self._authed_delete_device(tok1, "dev2", session, "u", "p") + self.assertEqual(channel.code, 401) # XXX why not a 403? + self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") + mock_password_provider.check_password.assert_called_once_with("@u:test", "p") + mock_password_provider.reset_mock() + + # Finally, check the request goes through when we allow it + mock_password_provider.check_password.return_value = defer.succeed(True) + channel = self._authed_delete_device(tok1, "dev2", session, "u", "p") + self.assertEqual(channel.code, 200) + mock_password_provider.check_password.assert_called_once_with("@u:test", "p") + + @override_config(providers_config(PasswordOnlyAuthProvider)) + def test_local_user_fallback_login(self): + """rejected login should fall back to local db""" + self.register_user("localuser", "localpass") + + # check_password must return an awaitable + mock_password_provider.check_password.return_value = defer.succeed(False) + channel = self._send_password_login("u", "p") + self.assertEqual(channel.code, 403, channel.result) + + channel = self._send_password_login("localuser", "localpass") + self.assertEqual(channel.code, 200, channel.result) + self.assertEqual("@localuser:test", channel.json_body["user_id"]) + + @override_config(providers_config(PasswordOnlyAuthProvider)) + def test_local_user_fallback_ui_auth(self): + """rejected login should fall back to local db""" + self.register_user("localuser", "localpass") + + # have the auth provider deny the request + mock_password_provider.check_password.return_value = defer.succeed(False) + + # log in twice, to get two devices + tok1 = self.login("localuser", "localpass") + self.login("localuser", "localpass", device_id="dev2") + mock_password_provider.check_password.reset_mock() + + # first delete should give a 401 + session = self._start_delete_device_session(tok1, "dev2") + mock_password_provider.check_password.assert_not_called() + + # Wrong password + channel = self._authed_delete_device(tok1, "dev2", session, "localuser", "xxx") + self.assertEqual(channel.code, 401) # XXX why not a 403? + self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") + mock_password_provider.check_password.assert_called_once_with( + "@localuser:test", "xxx" + ) + mock_password_provider.reset_mock() + + # Right password + channel = self._authed_delete_device( + tok1, "dev2", session, "localuser", "localpass" + ) + self.assertEqual(channel.code, 200) + mock_password_provider.check_password.assert_called_once_with( + "@localuser:test", "localpass" + ) + + @override_config( + { + **providers_config(PasswordOnlyAuthProvider), + "password_config": {"localdb_enabled": False}, + } + ) + def test_no_local_user_fallback_login(self): + """localdb_enabled can block login with the local password + """ + self.register_user("localuser", "localpass") + + # check_password must return an awaitable + mock_password_provider.check_password.return_value = defer.succeed(False) + channel = self._send_password_login("localuser", "localpass") + self.assertEqual(channel.code, 403) + self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") + mock_password_provider.check_password.assert_called_once_with( + "@localuser:test", "localpass" + ) + + @override_config( + { + **providers_config(PasswordOnlyAuthProvider), + "password_config": {"localdb_enabled": False}, + } + ) + def test_no_local_user_fallback_ui_auth(self): + """localdb_enabled can block ui auth with the local password + """ + self.register_user("localuser", "localpass") + + # allow login via the auth provider + mock_password_provider.check_password.return_value = defer.succeed(True) + + # log in twice, to get two devices + tok1 = self.login("localuser", "p") + self.login("localuser", "p", device_id="dev2") + mock_password_provider.check_password.reset_mock() + + # first delete should give a 401 + session = self._start_delete_device_session(tok1, "dev2") + mock_password_provider.check_password.assert_not_called() + + # now try deleting with the local password + mock_password_provider.check_password.return_value = defer.succeed(False) + channel = self._authed_delete_device( + tok1, "dev2", session, "localuser", "localpass" + ) + self.assertEqual(channel.code, 401) # XXX why not a 403? + self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") + mock_password_provider.check_password.assert_called_once_with( + "@localuser:test", "localpass" + ) + + @override_config( + { + **providers_config(PasswordOnlyAuthProvider), + "password_config": {"enabled": False}, + } + ) + def test_password_auth_disabled(self): + """password auth doesn't work if it's disabled across the board""" + # login flows should be empty + flows = self._get_login_flows() + self.assertEqual(flows, ADDITIONAL_LOGIN_FLOWS) + + # login shouldn't work and should be rejected with a 400 ("unknown login type") + channel = self._send_password_login("u", "p") + self.assertEqual(channel.code, 400, channel.result) + mock_password_provider.check_password.assert_not_called() + + @override_config(providers_config(CustomAuthProvider)) + def test_custom_auth_provider_login(self): + # login flows should have the custom flow and m.login.password, since we + # haven't disabled local password lookup. + # (password must come first, because reasons) + flows = self._get_login_flows() + self.assertEqual( + flows, + [{"type": "m.login.password"}, {"type": "test.login_type"}] + + ADDITIONAL_LOGIN_FLOWS, + ) + + # login with missing param should be rejected + channel = self._send_login("test.login_type", "u") + self.assertEqual(channel.code, 400, channel.result) + mock_password_provider.check_auth.assert_not_called() + + mock_password_provider.check_auth.return_value = defer.succeed("@user:bz") + channel = self._send_login("test.login_type", "u", test_field="y") + self.assertEqual(channel.code, 200, channel.result) + self.assertEqual("@user:bz", channel.json_body["user_id"]) + mock_password_provider.check_auth.assert_called_once_with( + "u", "test.login_type", {"test_field": "y"} + ) + mock_password_provider.reset_mock() + + # try a weird username. Again, it's unclear what we *expect* to happen + # in these cases, but at least we can guard against the API changing + # unexpectedly + mock_password_provider.check_auth.return_value = defer.succeed( + "@ MALFORMED! :bz" + ) + channel = self._send_login("test.login_type", " USER🙂NAME ", test_field=" abc ") + self.assertEqual(channel.code, 200, channel.result) + self.assertEqual("@ MALFORMED! :bz", channel.json_body["user_id"]) + mock_password_provider.check_auth.assert_called_once_with( + " USER🙂NAME ", "test.login_type", {"test_field": " abc "} + ) + + @override_config(providers_config(CustomAuthProvider)) + def test_custom_auth_provider_ui_auth(self): + # register the user and log in twice, to get two devices + self.register_user("localuser", "localpass") + tok1 = self.login("localuser", "localpass") + self.login("localuser", "localpass", device_id="dev2") + + # make the initial request which returns a 401 + channel = self._delete_device(tok1, "dev2") + self.assertEqual(channel.code, 401) + # Ensure that flows are what is expected. + self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"]) + self.assertIn({"stages": ["test.login_type"]}, channel.json_body["flows"]) + session = channel.json_body["session"] + + # missing param + body = { + "auth": { + "type": "test.login_type", + "identifier": {"type": "m.id.user", "user": "localuser"}, + # FIXME "identifier" is ignored + # https://github.com/matrix-org/synapse/issues/5665 + "user": "localuser", + "session": session, + }, + } + + channel = self._delete_device(tok1, "dev2", body) + self.assertEqual(channel.code, 400) + # there's a perfectly good M_MISSING_PARAM errcode, but heaven forfend we should + # use it... + self.assertIn("Missing parameters", channel.json_body["error"]) + mock_password_provider.check_auth.assert_not_called() + mock_password_provider.reset_mock() + + # right params, but authing as the wrong user + mock_password_provider.check_auth.return_value = defer.succeed("@user:bz") + body["auth"]["test_field"] = "foo" + channel = self._delete_device(tok1, "dev2", body) + self.assertEqual(channel.code, 403) + self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") + mock_password_provider.check_auth.assert_called_once_with( + "localuser", "test.login_type", {"test_field": "foo"} + ) + mock_password_provider.reset_mock() + + # and finally, succeed + mock_password_provider.check_auth.return_value = defer.succeed( + "@localuser:test" + ) + channel = self._delete_device(tok1, "dev2", body) + self.assertEqual(channel.code, 200) + mock_password_provider.check_auth.assert_called_once_with( + "localuser", "test.login_type", {"test_field": "foo"} + ) + + @override_config(providers_config(CustomAuthProvider)) + def test_custom_auth_provider_callback(self): + callback = Mock(return_value=defer.succeed(None)) + + mock_password_provider.check_auth.return_value = defer.succeed( + ("@user:bz", callback) + ) + channel = self._send_login("test.login_type", "u", test_field="y") + self.assertEqual(channel.code, 200, channel.result) + self.assertEqual("@user:bz", channel.json_body["user_id"]) + mock_password_provider.check_auth.assert_called_once_with( + "u", "test.login_type", {"test_field": "y"} + ) + + # check the args to the callback + callback.assert_called_once() + call_args, call_kwargs = callback.call_args + # should be one positional arg + self.assertEqual(len(call_args), 1) + self.assertEqual(call_args[0]["user_id"], "@user:bz") + for p in ["user_id", "access_token", "device_id", "home_server"]: + self.assertIn(p, call_args[0]) + + @override_config( + {**providers_config(CustomAuthProvider), "password_config": {"enabled": False}} + ) + def test_custom_auth_password_disabled(self): + """Test login with a custom auth provider where password login is disabled""" + self.register_user("localuser", "localpass") + + flows = self._get_login_flows() + self.assertEqual(flows, [{"type": "test.login_type"}] + ADDITIONAL_LOGIN_FLOWS) + + # login shouldn't work and should be rejected with a 400 ("unknown login type") + channel = self._send_password_login("localuser", "localpass") + self.assertEqual(channel.code, 400, channel.result) + mock_password_provider.check_auth.assert_not_called() + + @override_config( + { + **providers_config(CustomAuthProvider), + "password_config": {"localdb_enabled": False}, + } + ) + def test_custom_auth_no_local_user_fallback(self): + """Test login with a custom auth provider where the local db is disabled""" + self.register_user("localuser", "localpass") + + flows = self._get_login_flows() + self.assertEqual(flows, [{"type": "test.login_type"}] + ADDITIONAL_LOGIN_FLOWS) + + # password login shouldn't work and should be rejected with a 400 + # ("unknown login type") + channel = self._send_password_login("localuser", "localpass") + self.assertEqual(channel.code, 400, channel.result) + + test_custom_auth_no_local_user_fallback.skip = "currently broken" + + def _get_login_flows(self) -> JsonDict: + _, channel = self.make_request("GET", "/_matrix/client/r0/login") + self.assertEqual(channel.code, 200, channel.result) + return channel.json_body["flows"] + + def _send_password_login(self, user: str, password: str) -> FakeChannel: + return self._send_login(type="m.login.password", user=user, password=password) + + def _send_login(self, type, user, **params) -> FakeChannel: + params.update({"user": user, "type": type}) + _, channel = self.make_request("POST", "/_matrix/client/r0/login", params) + return channel + + def _start_delete_device_session(self, access_token, device_id) -> str: + """Make an initial delete device request, and return the UI Auth session ID""" + channel = self._delete_device(access_token, device_id) + self.assertEqual(channel.code, 401) + # Ensure that flows are what is expected. + self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"]) + return channel.json_body["session"] + + def _authed_delete_device( + self, + access_token: str, + device_id: str, + session: str, + user_id: str, + password: str, + ) -> FakeChannel: + """Make a delete device request, authenticating with the given uid/password""" + return self._delete_device( + access_token, + device_id, + { + "auth": { + "type": "m.login.password", + "identifier": {"type": "m.id.user", "user": user_id}, + # FIXME "identifier" is ignored + # https://github.com/matrix-org/synapse/issues/5665 + "user": user_id, + "password": password, + "session": session, + }, + }, + ) + + def _delete_device( + self, access_token: str, device: str, body: Union[JsonDict, bytes] = b"", + ) -> FakeChannel: + """Delete an individual device.""" + _, channel = self.make_request( + "DELETE", "devices/" + device, body, access_token=access_token + ) + return channel -- cgit 1.5.1 From 89f79307306ed117d9dcfe46a31a3fe1a1a5ceae Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 1 Dec 2020 13:04:03 +0000 Subject: Don't offer password login when it is disabled (#8835) Fix a minor bug where we would offer "m.login.password" login if a custom auth provider supported it, even if password login was disabled. --- changelog.d/8835.bugfix | 1 + synapse/handlers/auth.py | 10 ++- tests/handlers/test_password_providers.py | 108 +++++++++++++++++++++++++++++- 3 files changed, 115 insertions(+), 4 deletions(-) create mode 100644 changelog.d/8835.bugfix (limited to 'tests') diff --git a/changelog.d/8835.bugfix b/changelog.d/8835.bugfix new file mode 100644 index 0000000000..446d04aa55 --- /dev/null +++ b/changelog.d/8835.bugfix @@ -0,0 +1 @@ +Fix minor long-standing bug in login, where we would offer the `password` login type if a custom auth provider supported it, even if password login was disabled. diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 5163afd86c..588d3a60df 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -205,15 +205,23 @@ class AuthHandler(BaseHandler): # type in the list. (NB that the spec doesn't require us to do so and # clients which favour types that they don't understand over those that # they do are technically broken) + + # start out by assuming PASSWORD is enabled; we will remove it later if not. login_types = [] - if self._password_enabled: + if hs.config.password_localdb_enabled: login_types.append(LoginType.PASSWORD) + for provider in self.password_providers: if hasattr(provider, "get_supported_login_types"): for t in provider.get_supported_login_types().keys(): if t not in login_types: login_types.append(t) + + if not self._password_enabled: + login_types.remove(LoginType.PASSWORD) + self._supported_login_types = login_types + # Login types and UI Auth types have a heavy overlap, but are not # necessarily identical. Login types have SSO (and other login types) # added in the rest layer, see synapse.rest.client.v1.login.LoginRestServerlet.on_GET. diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py index edfab8a13a..dfbc4ee07e 100644 --- a/tests/handlers/test_password_providers.py +++ b/tests/handlers/test_password_providers.py @@ -70,6 +70,24 @@ class CustomAuthProvider: return mock_password_provider.check_auth(*args) +class PasswordCustomAuthProvider: + """A password_provider which implements password login via `check_auth`, as well + as a custom type.""" + + @staticmethod + def parse_config(self): + pass + + def __init__(self, config, account_handler): + pass + + def get_supported_login_types(self): + return {"m.login.password": ["password"], "test.login_type": ["test_field"]} + + def check_auth(self, *args): + return mock_password_provider.check_auth(*args) + + def providers_config(*providers: Type[Any]) -> dict: """Returns a config dict that will enable the given password auth providers""" return { @@ -246,7 +264,11 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): mock_password_provider.check_password.reset_mock() # first delete should give a 401 - session = self._start_delete_device_session(tok1, "dev2") + channel = self._delete_device(tok1, "dev2") + self.assertEqual(channel.code, 401) + # there are no valid flows here! + self.assertEqual(channel.json_body["flows"], []) + session = channel.json_body["session"] mock_password_provider.check_password.assert_not_called() # now try deleting with the local password @@ -410,6 +432,88 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): self.assertEqual(channel.code, 400, channel.result) mock_password_provider.check_auth.assert_not_called() + @override_config( + { + **providers_config(PasswordCustomAuthProvider), + "password_config": {"enabled": False}, + } + ) + def test_password_custom_auth_password_disabled_login(self): + """log in with a custom auth provider which implements password, but password + login is disabled""" + self.register_user("localuser", "localpass") + + flows = self._get_login_flows() + self.assertEqual(flows, [{"type": "test.login_type"}] + ADDITIONAL_LOGIN_FLOWS) + + # login shouldn't work and should be rejected with a 400 ("unknown login type") + channel = self._send_password_login("localuser", "localpass") + self.assertEqual(channel.code, 400, channel.result) + mock_password_provider.check_auth.assert_not_called() + + @override_config( + { + **providers_config(PasswordCustomAuthProvider), + "password_config": {"enabled": False}, + } + ) + def test_password_custom_auth_password_disabled_ui_auth(self): + """UI Auth with a custom auth provider which implements password, but password + login is disabled""" + # register the user and log in twice via the test login type to get two devices, + self.register_user("localuser", "localpass") + mock_password_provider.check_auth.return_value = defer.succeed( + "@localuser:test" + ) + channel = self._send_login("test.login_type", "localuser", test_field="") + self.assertEqual(channel.code, 200, channel.result) + tok1 = channel.json_body["access_token"] + + channel = self._send_login( + "test.login_type", "localuser", test_field="", device_id="dev2" + ) + self.assertEqual(channel.code, 200, channel.result) + + # make the initial request which returns a 401 + channel = self._delete_device(tok1, "dev2") + self.assertEqual(channel.code, 401) + # Ensure that flows are what is expected. In particular, "password" should *not* + # be present. + self.assertIn({"stages": ["test.login_type"]}, channel.json_body["flows"]) + session = channel.json_body["session"] + + mock_password_provider.reset_mock() + + # check that auth with password is rejected + body = { + "auth": { + "type": "m.login.password", + "identifier": {"type": "m.id.user", "user": "localuser"}, + # FIXME "identifier" is ignored + # https://github.com/matrix-org/synapse/issues/5665 + "user": "localuser", + "password": "localpass", + "session": session, + }, + } + + channel = self._delete_device(tok1, "dev2", body) + self.assertEqual(channel.code, 400) + self.assertEqual( + "Password login has been disabled.", channel.json_body["error"] + ) + mock_password_provider.check_auth.assert_not_called() + mock_password_provider.reset_mock() + + # successful auth + body["auth"]["type"] = "test.login_type" + body["auth"]["test_field"] = "x" + channel = self._delete_device(tok1, "dev2", body) + self.assertEqual(channel.code, 200) + mock_password_provider.check_auth.assert_called_once_with( + "localuser", "test.login_type", {"test_field": "x"} + ) + @override_config( { **providers_config(CustomAuthProvider), @@ -428,8 +532,6 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): channel = self._send_password_login("localuser", "localpass") self.assertEqual(channel.code, 400, channel.result) - test_custom_auth_no_local_user_fallback.skip = "currently broken" - def _get_login_flows(self) -> JsonDict: _, channel = self.make_request("GET", "/_matrix/client/r0/login") self.assertEqual(channel.code, 200, channel.result) -- cgit 1.5.1 From 4d9496559d25ba36eaea45d73e67e79b9d936450 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 1 Dec 2020 17:42:26 +0000 Subject: Support "identifier" dicts in UIA (#8848) The spec requires synapse to support `identifier` dicts for `m.login.password` user-interactive auth, which it did not (instead, it required an undocumented `user` parameter.) To fix this properly, we need to pull the code that interprets `identifier` into `AuthHandler.validate_login` so that it can be called from the UIA code. Fixes #5665. --- changelog.d/8848.bugfix | 1 + synapse/handlers/auth.py | 185 ++++++++++++++++++++++++++---- synapse/rest/client/v1/login.py | 107 +---------------- tests/handlers/test_password_providers.py | 11 +- tests/rest/client/v2_alpha/test_auth.py | 33 ++++-- 5 files changed, 190 insertions(+), 147 deletions(-) create mode 100644 changelog.d/8848.bugfix (limited to 'tests') diff --git a/changelog.d/8848.bugfix b/changelog.d/8848.bugfix new file mode 100644 index 0000000000..499e66f05b --- /dev/null +++ b/changelog.d/8848.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug which caused Synapse to require unspecified parameters during user-interactive authentication. diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 588d3a60df..8815f685b9 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -238,6 +238,13 @@ class AuthHandler(BaseHandler): burst_count=self.hs.config.rc_login_failed_attempts.burst_count, ) + # Ratelimitier for failed /login attempts + self._failed_login_attempts_ratelimiter = Ratelimiter( + clock=hs.get_clock(), + rate_hz=self.hs.config.rc_login_failed_attempts.per_second, + burst_count=self.hs.config.rc_login_failed_attempts.burst_count, + ) + self._clock = self.hs.get_clock() # Expire old UI auth sessions after a period of time. @@ -650,14 +657,8 @@ class AuthHandler(BaseHandler): res = await checker.check_auth(authdict, clientip=clientip) return res - # build a v1-login-style dict out of the authdict and fall back to the - # v1 code - user_id = authdict.get("user") - - if user_id is None: - raise SynapseError(400, "", Codes.MISSING_PARAM) - - (canonical_id, callback) = await self.validate_login(user_id, authdict) + # fall back to the v1 login flow + canonical_id, _ = await self.validate_login(authdict) return canonical_id def _get_params_recaptcha(self) -> dict: @@ -832,17 +833,17 @@ class AuthHandler(BaseHandler): return self._supported_login_types async def validate_login( - self, username: str, login_submission: Dict[str, Any] + self, login_submission: Dict[str, Any], ratelimit: bool = False, ) -> Tuple[str, Optional[Callable[[Dict[str, str]], None]]]: """Authenticates the user for the /login API - Also used by the user-interactive auth flow to validate - m.login.password auth types. + Also used by the user-interactive auth flow to validate auth types which don't + have an explicit UIA handler, including m.password.auth. Args: - username: username supplied by the user login_submission: the whole of the login submission (including 'type' and other relevant fields) + ratelimit: whether to apply the failed_login_attempt ratelimiter Returns: A tuple of the canonical user id, and optional callback to be called once the access token and device id are issued @@ -851,29 +852,161 @@ class AuthHandler(BaseHandler): SynapseError if there was a problem with the request LoginError if there was an authentication problem. """ - - if username.startswith("@"): - qualified_user_id = username - else: - qualified_user_id = UserID(username, self.hs.hostname).to_string() - login_type = login_submission.get("type") - known_login_type = False + + # ideally, we wouldn't be checking the identifier unless we know we have a login + # method which uses it (https://github.com/matrix-org/synapse/issues/8836) + # + # But the auth providers' check_auth interface requires a username, so in + # practice we can only support login methods which we can map to a username + # anyway. # special case to check for "password" for the check_password interface # for the auth providers password = login_submission.get("password") - if login_type == LoginType.PASSWORD: if not self._password_enabled: raise SynapseError(400, "Password login has been disabled.") - if not password: - raise SynapseError(400, "Missing parameter: password") + if not isinstance(password, str): + raise SynapseError(400, "Bad parameter: password", Codes.INVALID_PARAM) + + # map old-school login fields into new-school "identifier" fields. + identifier_dict = convert_client_dict_legacy_fields_to_identifier( + login_submission + ) + + # convert phone type identifiers to generic threepids + if identifier_dict["type"] == "m.id.phone": + identifier_dict = login_id_phone_to_thirdparty(identifier_dict) + + # convert threepid identifiers to user IDs + if identifier_dict["type"] == "m.id.thirdparty": + address = identifier_dict.get("address") + medium = identifier_dict.get("medium") + + if medium is None or address is None: + raise SynapseError(400, "Invalid thirdparty identifier") + + # For emails, canonicalise the address. + # We store all email addresses canonicalised in the DB. + # (See add_threepid in synapse/handlers/auth.py) + if medium == "email": + try: + address = canonicalise_email(address) + except ValueError as e: + raise SynapseError(400, str(e)) + + # We also apply account rate limiting using the 3PID as a key, as + # otherwise using 3PID bypasses the ratelimiting based on user ID. + if ratelimit: + self._failed_login_attempts_ratelimiter.ratelimit( + (medium, address), update=False + ) + + # Check for login providers that support 3pid login types + if login_type == LoginType.PASSWORD: + # we've already checked that there is a (valid) password field + assert isinstance(password, str) + ( + canonical_user_id, + callback_3pid, + ) = await self.check_password_provider_3pid(medium, address, password) + if canonical_user_id: + # Authentication through password provider and 3pid succeeded + return canonical_user_id, callback_3pid + + # No password providers were able to handle this 3pid + # Check local store + user_id = await self.hs.get_datastore().get_user_id_by_threepid( + medium, address + ) + if not user_id: + logger.warning( + "unknown 3pid identifier medium %s, address %r", medium, address + ) + # We mark that we've failed to log in here, as + # `check_password_provider_3pid` might have returned `None` due + # to an incorrect password, rather than the account not + # existing. + # + # If it returned None but the 3PID was bound then we won't hit + # this code path, which is fine as then the per-user ratelimit + # will kick in below. + if ratelimit: + self._failed_login_attempts_ratelimiter.can_do_action( + (medium, address) + ) + raise LoginError(403, "", errcode=Codes.FORBIDDEN) + + identifier_dict = {"type": "m.id.user", "user": user_id} + + # by this point, the identifier should be an m.id.user: if it's anything + # else, we haven't understood it. + if identifier_dict["type"] != "m.id.user": + raise SynapseError(400, "Unknown login identifier type") + + username = identifier_dict.get("user") + if not username: + raise SynapseError(400, "User identifier is missing 'user' key") + + if username.startswith("@"): + qualified_user_id = username + else: + qualified_user_id = UserID(username, self.hs.hostname).to_string() + + # Check if we've hit the failed ratelimit (but don't update it) + if ratelimit: + self._failed_login_attempts_ratelimiter.ratelimit( + qualified_user_id.lower(), update=False + ) + + try: + return await self._validate_userid_login(username, login_submission) + except LoginError: + # The user has failed to log in, so we need to update the rate + # limiter. Using `can_do_action` avoids us raising a ratelimit + # exception and masking the LoginError. The actual ratelimiting + # should have happened above. + if ratelimit: + self._failed_login_attempts_ratelimiter.can_do_action( + qualified_user_id.lower() + ) + raise + + async def _validate_userid_login( + self, username: str, login_submission: Dict[str, Any], + ) -> Tuple[str, Optional[Callable[[Dict[str, str]], None]]]: + """Helper for validate_login + + Handles login, once we've mapped 3pids onto userids + + Args: + username: the username, from the identifier dict + login_submission: the whole of the login submission + (including 'type' and other relevant fields) + Returns: + A tuple of the canonical user id, and optional callback + to be called once the access token and device id are issued + Raises: + StoreError if there was a problem accessing the database + SynapseError if there was a problem with the request + LoginError if there was an authentication problem. + """ + if username.startswith("@"): + qualified_user_id = username + else: + qualified_user_id = UserID(username, self.hs.hostname).to_string() + + login_type = login_submission.get("type") + known_login_type = False for provider in self.password_providers: if hasattr(provider, "check_password") and login_type == LoginType.PASSWORD: known_login_type = True - is_valid = await provider.check_password(qualified_user_id, password) + # we've already checked that there is a (valid) password field + is_valid = await provider.check_password( + qualified_user_id, login_submission["password"] + ) if is_valid: return qualified_user_id, None @@ -914,8 +1047,12 @@ class AuthHandler(BaseHandler): if login_type == LoginType.PASSWORD and self.hs.config.password_localdb_enabled: known_login_type = True + # we've already checked that there is a (valid) password field + password = login_submission["password"] + assert isinstance(password, str) + canonical_user_id = await self._check_local_password( - qualified_user_id, password # type: ignore + qualified_user_id, password ) if canonical_user_id: diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 074bdd66c9..d7ae148214 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -19,10 +19,6 @@ from typing import Awaitable, Callable, Dict, Optional from synapse.api.errors import Codes, LoginError, SynapseError from synapse.api.ratelimiting import Ratelimiter from synapse.appservice import ApplicationService -from synapse.handlers.auth import ( - convert_client_dict_legacy_fields_to_identifier, - login_id_phone_to_thirdparty, -) from synapse.http.server import finish_request from synapse.http.servlet import ( RestServlet, @@ -33,7 +29,6 @@ from synapse.http.site import SynapseRequest from synapse.rest.client.v2_alpha._base import client_patterns from synapse.rest.well_known import WellKnownBuilder from synapse.types import JsonDict, UserID -from synapse.util.threepids import canonicalise_email logger = logging.getLogger(__name__) @@ -78,11 +73,6 @@ class LoginRestServlet(RestServlet): rate_hz=self.hs.config.rc_login_account.per_second, burst_count=self.hs.config.rc_login_account.burst_count, ) - self._failed_attempts_ratelimiter = Ratelimiter( - clock=hs.get_clock(), - rate_hz=self.hs.config.rc_login_failed_attempts.per_second, - burst_count=self.hs.config.rc_login_failed_attempts.burst_count, - ) def on_GET(self, request: SynapseRequest): flows = [] @@ -140,17 +130,6 @@ class LoginRestServlet(RestServlet): result["well_known"] = well_known_data return 200, result - def _get_qualified_user_id(self, identifier): - if identifier["type"] != "m.id.user": - raise SynapseError(400, "Unknown login identifier type") - if "user" not in identifier: - raise SynapseError(400, "User identifier is missing 'user' key") - - if identifier["user"].startswith("@"): - return identifier["user"] - else: - return UserID(identifier["user"], self.hs.hostname).to_string() - async def _do_appservice_login( self, login_submission: JsonDict, appservice: ApplicationService ): @@ -201,91 +180,9 @@ class LoginRestServlet(RestServlet): login_submission.get("address"), login_submission.get("user"), ) - identifier = convert_client_dict_legacy_fields_to_identifier(login_submission) - - # convert phone type identifiers to generic threepids - if identifier["type"] == "m.id.phone": - identifier = login_id_phone_to_thirdparty(identifier) - - # convert threepid identifiers to user IDs - if identifier["type"] == "m.id.thirdparty": - address = identifier.get("address") - medium = identifier.get("medium") - - if medium is None or address is None: - raise SynapseError(400, "Invalid thirdparty identifier") - - # For emails, canonicalise the address. - # We store all email addresses canonicalised in the DB. - # (See add_threepid in synapse/handlers/auth.py) - if medium == "email": - try: - address = canonicalise_email(address) - except ValueError as e: - raise SynapseError(400, str(e)) - - # We also apply account rate limiting using the 3PID as a key, as - # otherwise using 3PID bypasses the ratelimiting based on user ID. - self._failed_attempts_ratelimiter.ratelimit((medium, address), update=False) - - # Check for login providers that support 3pid login types - ( - canonical_user_id, - callback_3pid, - ) = await self.auth_handler.check_password_provider_3pid( - medium, address, login_submission["password"] - ) - if canonical_user_id: - # Authentication through password provider and 3pid succeeded - - result = await self._complete_login( - canonical_user_id, login_submission, callback_3pid - ) - return result - - # No password providers were able to handle this 3pid - # Check local store - user_id = await self.hs.get_datastore().get_user_id_by_threepid( - medium, address - ) - if not user_id: - logger.warning( - "unknown 3pid identifier medium %s, address %r", medium, address - ) - # We mark that we've failed to log in here, as - # `check_password_provider_3pid` might have returned `None` due - # to an incorrect password, rather than the account not - # existing. - # - # If it returned None but the 3PID was bound then we won't hit - # this code path, which is fine as then the per-user ratelimit - # will kick in below. - self._failed_attempts_ratelimiter.can_do_action((medium, address)) - raise LoginError(403, "", errcode=Codes.FORBIDDEN) - - identifier = {"type": "m.id.user", "user": user_id} - - # by this point, the identifier should be an m.id.user: if it's anything - # else, we haven't understood it. - qualified_user_id = self._get_qualified_user_id(identifier) - - # Check if we've hit the failed ratelimit (but don't update it) - self._failed_attempts_ratelimiter.ratelimit( - qualified_user_id.lower(), update=False + canonical_user_id, callback = await self.auth_handler.validate_login( + login_submission, ratelimit=True ) - - try: - canonical_user_id, callback = await self.auth_handler.validate_login( - identifier["user"], login_submission - ) - except LoginError: - # The user has failed to log in, so we need to update the rate - # limiter. Using `can_do_action` avoids us raising a ratelimit - # exception and masking the LoginError. The actual ratelimiting - # should have happened above. - self._failed_attempts_ratelimiter.can_do_action(qualified_user_id.lower()) - raise - result = await self._complete_login( canonical_user_id, login_submission, callback ) diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py index dfbc4ee07e..22b9a11dc0 100644 --- a/tests/handlers/test_password_providers.py +++ b/tests/handlers/test_password_providers.py @@ -358,9 +358,6 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): "auth": { "type": "test.login_type", "identifier": {"type": "m.id.user", "user": "localuser"}, - # FIXME "identifier" is ignored - # https://github.com/matrix-org/synapse/issues/5665 - "user": "localuser", "session": session, }, } @@ -489,9 +486,6 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): "auth": { "type": "m.login.password", "identifier": {"type": "m.id.user", "user": "localuser"}, - # FIXME "identifier" is ignored - # https://github.com/matrix-org/synapse/issues/5665 - "user": "localuser", "password": "localpass", "session": session, }, @@ -541,7 +535,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): return self._send_login(type="m.login.password", user=user, password=password) def _send_login(self, type, user, **params) -> FakeChannel: - params.update({"user": user, "type": type}) + params.update({"identifier": {"type": "m.id.user", "user": user}, "type": type}) _, channel = self.make_request("POST", "/_matrix/client/r0/login", params) return channel @@ -569,9 +563,6 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): "auth": { "type": "m.login.password", "identifier": {"type": "m.id.user", "user": user_id}, - # FIXME "identifier" is ignored - # https://github.com/matrix-org/synapse/issues/5665 - "user": user_id, "password": password, "session": session, }, diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/v2_alpha/test_auth.py index f684c37db5..77246e478f 100644 --- a/tests/rest/client/v2_alpha/test_auth.py +++ b/tests/rest/client/v2_alpha/test_auth.py @@ -38,11 +38,6 @@ class DummyRecaptchaChecker(UserInteractiveAuthChecker): return succeed(True) -class DummyPasswordChecker(UserInteractiveAuthChecker): - def check_auth(self, authdict, clientip): - return succeed(authdict["identifier"]["user"]) - - class FallbackAuthTests(unittest.HomeserverTestCase): servlets = [ @@ -162,9 +157,6 @@ class UIAuthTests(unittest.HomeserverTestCase): ] def prepare(self, reactor, clock, hs): - auth_handler = hs.get_auth_handler() - auth_handler.checkers[LoginType.PASSWORD] = DummyPasswordChecker(hs) - self.user_pass = "pass" self.user = self.register_user("test", self.user_pass) self.user_tok = self.login("test", self.user_pass) @@ -234,6 +226,31 @@ class UIAuthTests(unittest.HomeserverTestCase): }, ) + def test_grandfathered_identifier(self): + """Check behaviour without "identifier" dict + + Synapse used to require clients to submit a "user" field for m.login.password + UIA - check that still works. + """ + + device_id = self.get_device_ids()[0] + channel = self.delete_device(device_id, 401) + session = channel.json_body["session"] + + # Make another request providing the UI auth flow. + self.delete_device( + device_id, + 200, + { + "auth": { + "type": "m.login.password", + "user": self.user, + "password": self.user_pass, + "session": session, + }, + }, + ) + def test_can_change_body(self): """ The client dict can be modified during the user interactive authentication session. -- cgit 1.5.1 From edb3d3f82716c2b5c903ddb4d0df155e06c5c9e9 Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Wed, 2 Dec 2020 10:38:18 +0000 Subject: Allow specifying room version in 'RestHelper.create_room_as' and add typing (#8854) This PR adds a `room_version` argument to the `RestHelper`'s `create_room_as` function for tests. I plan to use this for testing knocking, which currently uses an unstable room version. --- changelog.d/8854.misc | 1 + tests/rest/client/v1/utils.py | 27 +++++++++++++++++++++++++-- 2 files changed, 26 insertions(+), 2 deletions(-) create mode 100644 changelog.d/8854.misc (limited to 'tests') diff --git a/changelog.d/8854.misc b/changelog.d/8854.misc new file mode 100644 index 0000000000..5895df2d5c --- /dev/null +++ b/changelog.d/8854.misc @@ -0,0 +1 @@ +Allow for specifying a room version when creating a room in unit tests via `RestHelper.create_room_as`. \ No newline at end of file diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py index b58768675b..737c38c396 100644 --- a/tests/rest/client/v1/utils.py +++ b/tests/rest/client/v1/utils.py @@ -41,14 +41,37 @@ class RestHelper: auth_user_id = attr.ib() def create_room_as( - self, room_creator=None, is_public=True, tok=None, expect_code=200, - ): + self, + room_creator: str = None, + is_public: bool = True, + room_version: str = None, + tok: str = None, + expect_code: int = 200, + ) -> str: + """ + Create a room. + + Args: + room_creator: The user ID to create the room with. + is_public: If True, the `visibility` parameter will be set to the + default (public). Otherwise, the `visibility` parameter will be set + to "private". + room_version: The room version to create the room as. Defaults to Synapse's + default room version. + tok: The access token to use in the request. + expect_code: The expected HTTP response code. + + Returns: + The ID of the newly created room. + """ temp_id = self.auth_user_id self.auth_user_id = room_creator path = "/_matrix/client/r0/createRoom" content = {} if not is_public: content["visibility"] = "private" + if room_version: + content["room_version"] = room_version if tok: path = path + "?access_token=%s" % tok -- cgit 1.5.1 From d3ed93504bb6bb8ad138e356e3c74b6a7286299b Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Wed, 2 Dec 2020 10:38:50 +0000 Subject: Create a `PasswordProvider` wrapper object (#8849) The idea here is to abstract out all the conditional code which tests which methods a given password provider has, to provide a consistent interface. --- changelog.d/8849.misc | 1 + synapse/handlers/auth.py | 203 ++++++++++++++++++++++-------- tests/handlers/test_password_providers.py | 5 +- 3 files changed, 152 insertions(+), 57 deletions(-) create mode 100644 changelog.d/8849.misc (limited to 'tests') diff --git a/changelog.d/8849.misc b/changelog.d/8849.misc new file mode 100644 index 0000000000..3dd496ce61 --- /dev/null +++ b/changelog.d/8849.misc @@ -0,0 +1 @@ +Refactor `password_auth_provider` support code. diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 8815f685b9..c7dc07008a 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- # Copyright 2014 - 2016 OpenMarket Ltd # Copyright 2017 Vector Creations Ltd +# Copyright 2019 - 2020 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -25,6 +26,7 @@ from typing import ( Dict, Iterable, List, + Mapping, Optional, Tuple, Union, @@ -181,17 +183,12 @@ class AuthHandler(BaseHandler): # better way to break the loop account_handler = ModuleApi(hs, self) - self.password_providers = [] - for module, config in hs.config.password_providers: - try: - self.password_providers.append( - module(config=config, account_handler=account_handler) - ) - except Exception as e: - logger.error("Error while initializing %r: %s", module, e) - raise + self.password_providers = [ + PasswordProvider.load(module, config, account_handler) + for module, config in hs.config.password_providers + ] - logger.info("Extra password_providers: %r", self.password_providers) + logger.info("Extra password_providers: %s", self.password_providers) self.hs = hs # FIXME better possibility to access registrationHandler later? self.macaroon_gen = hs.get_macaroon_generator() @@ -853,6 +850,8 @@ class AuthHandler(BaseHandler): LoginError if there was an authentication problem. """ login_type = login_submission.get("type") + if not isinstance(login_type, str): + raise SynapseError(400, "Bad parameter: type", Codes.INVALID_PARAM) # ideally, we wouldn't be checking the identifier unless we know we have a login # method which uses it (https://github.com/matrix-org/synapse/issues/8836) @@ -998,24 +997,12 @@ class AuthHandler(BaseHandler): qualified_user_id = UserID(username, self.hs.hostname).to_string() login_type = login_submission.get("type") + # we already checked that we have a valid login type + assert isinstance(login_type, str) + known_login_type = False for provider in self.password_providers: - if hasattr(provider, "check_password") and login_type == LoginType.PASSWORD: - known_login_type = True - # we've already checked that there is a (valid) password field - is_valid = await provider.check_password( - qualified_user_id, login_submission["password"] - ) - if is_valid: - return qualified_user_id, None - - if not hasattr(provider, "get_supported_login_types") or not hasattr( - provider, "check_auth" - ): - # this password provider doesn't understand custom login types - continue - supported_login_types = provider.get_supported_login_types() if login_type not in supported_login_types: # this password provider doesn't understand this login type @@ -1040,8 +1027,6 @@ class AuthHandler(BaseHandler): result = await provider.check_auth(username, login_type, login_dict) if result: - if isinstance(result, str): - result = (result, None) return result if login_type == LoginType.PASSWORD and self.hs.config.password_localdb_enabled: @@ -1083,19 +1068,9 @@ class AuthHandler(BaseHandler): unsuccessful, `user_id` and `callback` are both `None`. """ for provider in self.password_providers: - if hasattr(provider, "check_3pid_auth"): - # This function is able to return a deferred that either - # resolves None, meaning authentication failure, or upon - # success, to a str (which is the user_id) or a tuple of - # (user_id, callback_func), where callback_func should be run - # after we've finished everything else - result = await provider.check_3pid_auth(medium, address, password) - if result: - # Check if the return value is a str or a tuple - if isinstance(result, str): - # If it's a str, set callback function to None - result = (result, None) - return result + result = await provider.check_3pid_auth(medium, address, password) + if result: + return result return None, None @@ -1153,16 +1128,11 @@ class AuthHandler(BaseHandler): # see if any of our auth providers want to know about this for provider in self.password_providers: - if hasattr(provider, "on_logged_out"): - # This might return an awaitable, if it does block the log out - # until it completes. - result = provider.on_logged_out( - user_id=user_info.user_id, - device_id=user_info.device_id, - access_token=access_token, - ) - if inspect.isawaitable(result): - await result + await provider.on_logged_out( + user_id=user_info.user_id, + device_id=user_info.device_id, + access_token=access_token, + ) # delete pushers associated with this access token if user_info.token_id is not None: @@ -1191,11 +1161,10 @@ class AuthHandler(BaseHandler): # see if any of our auth providers want to know about this for provider in self.password_providers: - if hasattr(provider, "on_logged_out"): - for token, token_id, device_id in tokens_and_devices: - await provider.on_logged_out( - user_id=user_id, device_id=device_id, access_token=token - ) + for token, token_id, device_id in tokens_and_devices: + await provider.on_logged_out( + user_id=user_id, device_id=device_id, access_token=token + ) # delete pushers associated with the access tokens await self.hs.get_pusherpool().remove_pushers_by_access_token( @@ -1519,3 +1488,127 @@ class MacaroonGenerator: macaroon.add_first_party_caveat("gen = 1") macaroon.add_first_party_caveat("user_id = %s" % (user_id,)) return macaroon + + +class PasswordProvider: + """Wrapper for a password auth provider module + + This class abstracts out all of the backwards-compatibility hacks for + password providers, to provide a consistent interface. + """ + + @classmethod + def load(cls, module, config, module_api: ModuleApi) -> "PasswordProvider": + try: + pp = module(config=config, account_handler=module_api) + except Exception as e: + logger.error("Error while initializing %r: %s", module, e) + raise + return cls(pp, module_api) + + def __init__(self, pp, module_api: ModuleApi): + self._pp = pp + self._module_api = module_api + + self._supported_login_types = {} + + # grandfather in check_password support + if hasattr(self._pp, "check_password"): + self._supported_login_types[LoginType.PASSWORD] = ("password",) + + g = getattr(self._pp, "get_supported_login_types", None) + if g: + self._supported_login_types.update(g()) + + def __str__(self): + return str(self._pp) + + def get_supported_login_types(self) -> Mapping[str, Iterable[str]]: + """Get the login types supported by this password provider + + Returns a map from a login type identifier (such as m.login.password) to an + iterable giving the fields which must be provided by the user in the submission + to the /login API. + + This wrapper adds m.login.password to the list if the underlying password + provider supports the check_password() api. + """ + return self._supported_login_types + + async def check_auth( + self, username: str, login_type: str, login_dict: JsonDict + ) -> Optional[Tuple[str, Optional[Callable]]]: + """Check if the user has presented valid login credentials + + This wrapper also calls check_password() if the underlying password provider + supports the check_password() api and the login type is m.login.password. + + Args: + username: user id presented by the client. Either an MXID or an unqualified + username. + + login_type: the login type being attempted - one of the types returned by + get_supported_login_types() + + login_dict: the dictionary of login secrets passed by the client. + + Returns: (user_id, callback) where `user_id` is the fully-qualified mxid of the + user, and `callback` is an optional callback which will be called with the + result from the /login call (including access_token, device_id, etc.) + """ + # first grandfather in a call to check_password + if login_type == LoginType.PASSWORD: + g = getattr(self._pp, "check_password", None) + if g: + qualified_user_id = self._module_api.get_qualified_user_id(username) + is_valid = await self._pp.check_password( + qualified_user_id, login_dict["password"] + ) + if is_valid: + return qualified_user_id, None + + g = getattr(self._pp, "check_auth", None) + if not g: + return None + result = await g(username, login_type, login_dict) + + # Check if the return value is a str or a tuple + if isinstance(result, str): + # If it's a str, set callback function to None + return result, None + + return result + + async def check_3pid_auth( + self, medium: str, address: str, password: str + ) -> Optional[Tuple[str, Optional[Callable]]]: + g = getattr(self._pp, "check_3pid_auth", None) + if not g: + return None + + # This function is able to return a deferred that either + # resolves None, meaning authentication failure, or upon + # success, to a str (which is the user_id) or a tuple of + # (user_id, callback_func), where callback_func should be run + # after we've finished everything else + result = await g(medium, address, password) + + # Check if the return value is a str or a tuple + if isinstance(result, str): + # If it's a str, set callback function to None + return result, None + + return result + + async def on_logged_out( + self, user_id: str, device_id: Optional[str], access_token: str + ) -> None: + g = getattr(self._pp, "on_logged_out", None) + if not g: + return + + # This might return an awaitable, if it does block the log out + # until it completes. + result = g(user_id=user_id, device_id=device_id, access_token=access_token,) + if inspect.isawaitable(result): + await result diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py index 22b9a11dc0..ceaf0902d2 100644 --- a/tests/handlers/test_password_providers.py +++ b/tests/handlers/test_password_providers.py @@ -266,8 +266,9 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): # first delete should give a 401 channel = self._delete_device(tok1, "dev2") self.assertEqual(channel.code, 401) - # there are no valid flows here! - self.assertEqual(channel.json_body["flows"], []) + # m.login.password UIA is permitted because the auth provider allows it, + # even though the localdb does not. + self.assertEqual(channel.json_body["flows"], [{"stages": ["m.login.password"]}]) session = channel.json_body["session"] mock_password_provider.check_password.assert_not_called() -- cgit 1.5.1