diff options
author | Richard van der Hoff <richard@matrix.org> | 2020-12-01 00:15:36 +0000 |
---|---|---|
committer | Richard van der Hoff <richard@matrix.org> | 2020-12-02 18:54:15 +0000 |
commit | 0bac276890567ef3a3fafd7f5b7b5cac91a1031b (patch) | |
tree | 676f10e3ed3c786906f01f7ab5d8d2a1d752ea3c | |
parent | Factor out FakeResponse from test_oidc (diff) | |
download | synapse-0bac276890567ef3a3fafd7f5b7b5cac91a1031b.tar.xz |
UIA: offer only available auth flows
During user-interactive auth, do not offer password auth to users with no password, nor SSO auth to users with no SSO. Fixes #7559.
Diffstat (limited to '')
-rw-r--r-- | synapse/handlers/auth.py | 58 | ||||
-rw-r--r-- | synapse/storage/databases/main/registration.py | 25 | ||||
-rw-r--r-- | synapse/storage/databases/main/schema/delta/58/25user_external_ids_user_id_idx.sql | 17 | ||||
-rw-r--r-- | tests/rest/client/v1/utils.py | 116 | ||||
-rw-r--r-- | tests/rest/client/v2_alpha/test_auth.py | 94 | ||||
-rw-r--r-- | tests/server.py | 1 |
6 files changed, 278 insertions, 33 deletions
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index c7dc07008a..2e72298e05 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -193,9 +193,7 @@ class AuthHandler(BaseHandler): self.hs = hs # FIXME better possibility to access registrationHandler later? self.macaroon_gen = hs.get_macaroon_generator() self._password_enabled = hs.config.password_enabled - self._sso_enabled = ( - hs.config.cas_enabled or hs.config.saml2_enabled or hs.config.oidc_enabled - ) + self._password_localdb_enabled = hs.config.password_localdb_enabled # we keep this as a list despite the O(N^2) implication so that we can # keep PASSWORD first and avoid confusing clients which pick the first @@ -205,7 +203,7 @@ class AuthHandler(BaseHandler): # start out by assuming PASSWORD is enabled; we will remove it later if not. login_types = [] - if hs.config.password_localdb_enabled: + if self._password_localdb_enabled: login_types.append(LoginType.PASSWORD) for provider in self.password_providers: @@ -219,14 +217,6 @@ class AuthHandler(BaseHandler): self._supported_login_types = login_types - # Login types and UI Auth types have a heavy overlap, but are not - # necessarily identical. Login types have SSO (and other login types) - # added in the rest layer, see synapse.rest.client.v1.login.LoginRestServerlet.on_GET. - ui_auth_types = login_types.copy() - if self._sso_enabled: - ui_auth_types.append(LoginType.SSO) - self._supported_ui_auth_types = ui_auth_types - # Ratelimiter for failed auth during UIA. Uses same ratelimit config # as per `rc_login.failed_attempts`. self._failed_uia_attempts_ratelimiter = Ratelimiter( @@ -339,7 +329,10 @@ class AuthHandler(BaseHandler): self._failed_uia_attempts_ratelimiter.ratelimit(user_id, update=False) # build a list of supported flows - flows = [[login_type] for login_type in self._supported_ui_auth_types] + supported_ui_auth_types = await self._get_available_ui_auth_types( + requester.user + ) + flows = [[login_type] for login_type in supported_ui_auth_types] try: result, params, session_id = await self.check_ui_auth( @@ -351,7 +344,7 @@ class AuthHandler(BaseHandler): raise # find the completed login type - for login_type in self._supported_ui_auth_types: + for login_type in supported_ui_auth_types: if login_type not in result: continue @@ -367,6 +360,41 @@ class AuthHandler(BaseHandler): return params, session_id + async def _get_available_ui_auth_types(self, user: UserID) -> Iterable[str]: + """Get a list of the authentication types this user can use + """ + + ui_auth_types = set() + + # if the HS supports password auth, and the user has a non-null password, we + # support password auth + if self._password_localdb_enabled and self._password_enabled: + lookupres = await self._find_user_id_and_pwd_hash(user.to_string()) + if lookupres: + _, password_hash = lookupres + if password_hash: + ui_auth_types.add(LoginType.PASSWORD) + + # also allow auth from password providers + for provider in self.password_providers: + for t in provider.get_supported_login_types().keys(): + if t == LoginType.PASSWORD and not self._password_enabled: + continue + ui_auth_types.add(t) + + # if sso is enabled, allow the user to log in via SSO iff they have a mapping + # from sso to mxid. + if self.hs.config.saml2.saml2_enabled or self.hs.config.oidc.oidc_enabled: + if await self.store.get_external_ids_by_user(user.to_string()): + ui_auth_types.add(LoginType.SSO) + + # Our CAS impl does not (yet) correctly register users in user_external_ids, + # so always offer that if it's available. + if self.hs.config.cas.cas_enabled: + ui_auth_types.add(LoginType.SSO) + + return ui_auth_types + def get_enabled_auth_types(self): """Return the enabled user-interactive authentication types @@ -1029,7 +1057,7 @@ class AuthHandler(BaseHandler): if result: return result - if login_type == LoginType.PASSWORD and self.hs.config.password_localdb_enabled: + if login_type == LoginType.PASSWORD and self._password_localdb_enabled: known_login_type = True # we've already checked that there is a (valid) password field diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index fedb8a6c26..ff96c34c2e 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -463,6 +463,23 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): desc="get_user_by_external_id", ) + async def get_external_ids_by_user(self, mxid: str) -> List[Tuple[str, str]]: + """Look up external ids for the given user + + Args: + mxid: the MXID to be looked up + + Returns: + Tuples of (auth_provider, external_id) + """ + res = await self.db_pool.simple_select_list( + table="user_external_ids", + keyvalues={"user_id": mxid}, + retcols=("auth_provider", "external_id"), + desc="get_external_ids_by_user", + ) + return [(r["auth_provider"], r["external_id"]) for r in res] + async def count_all_users(self): """Counts all users registered on the homeserver.""" @@ -963,6 +980,14 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): "users_set_deactivated_flag", self._background_update_set_deactivated_flag ) + self.db_pool.updates.register_background_index_update( + "user_external_ids_user_id_idx", + index_name="user_external_ids_user_id_idx", + table="user_external_ids", + columns=["user_id"], + unique=False, + ) + async def _background_update_set_deactivated_flag(self, progress, batch_size): """Retrieves a list of all deactivated users and sets the 'deactivated' flag to 1 for each of them. diff --git a/synapse/storage/databases/main/schema/delta/58/25user_external_ids_user_id_idx.sql b/synapse/storage/databases/main/schema/delta/58/25user_external_ids_user_id_idx.sql new file mode 100644 index 0000000000..8f5e65aa71 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/25user_external_ids_user_id_idx.sql @@ -0,0 +1,17 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +INSERT INTO background_updates (ordering, update_name, progress_json) VALUES + (5825, 'user_external_ids_user_id_idx', '{}'); diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py index 737c38c396..5a18af8d34 100644 --- a/tests/rest/client/v1/utils.py +++ b/tests/rest/client/v1/utils.py @@ -2,7 +2,7 @@ # Copyright 2014-2016 OpenMarket Ltd # Copyright 2017 Vector Creations Ltd # Copyright 2018-2019 New Vector Ltd -# Copyright 2019 The Matrix.org Foundation C.I.C. +# Copyright 2019-2020 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -17,17 +17,23 @@ # limitations under the License. import json +import re import time +import urllib.parse from typing import Any, Dict, Optional +from mock import patch + import attr from twisted.web.resource import Resource from twisted.web.server import Site from synapse.api.constants import Membership +from synapse.types import JsonDict from tests.server import FakeSite, make_request +from tests.test_utils import FakeResponse @attr.s @@ -344,3 +350,111 @@ class RestHelper: ) return channel.json_body + + def login_via_oidc(self, remote_user_id: str) -> JsonDict: + """Log in (as a new user) via OIDC + + Returns the result of the final token login. + + Requires that "oidc_config" in the homeserver config be set appropriately + (TEST_OIDC_CONFIG is a suitable example) - and by implication, needs a + "public_base_url". + + Also requires the login servlet and the OIDC callback resource to be mounted at + the normal places. + """ + client_redirect_url = "https://x" + + # first hit the redirect url (which will issue a cookie and state) + _, channel = make_request( + self.hs.get_reactor(), + self.site, + "GET", + "/login/sso/redirect?redirectUrl=" + client_redirect_url, + ) + # that will redirect to the OIDC IdP, but we skip that and go straight + # back to synapse's OIDC callback resource. However, we do need the "state" + # param that synapse passes to the IdP via query params, and the cookie that + # synapse passes to the client. + assert channel.code == 302 + oauth_uri = channel.headers.getRawHeaders("Location")[0] + params = urllib.parse.parse_qs(urllib.parse.urlparse(oauth_uri).query) + redirect_uri = "%s?%s" % ( + urllib.parse.urlparse(params["redirect_uri"][0]).path, + urllib.parse.urlencode({"state": params["state"][0], "code": "TEST_CODE"}), + ) + cookies = {} + for h in channel.headers.getRawHeaders("Set-Cookie"): + parts = h.split(";") + k, v = parts[0].split("=", maxsplit=1) + cookies[k] = v + + # before we hit the callback uri, stub out some methods in the http client so + # that we don't have to handle full HTTPS requests. + + # (expected url, json response) pairs, in the order we expect them. + expected_requests = [ + # first we get a hit to the token endpoint, which we tell to return + # a dummy OIDC access token + ("https://issuer.test/token", {"access_token": "TEST"}), + # and then one to the user_info endpoint, which returns our remote user id. + ("https://issuer.test/userinfo", {"sub": remote_user_id}), + ] + + async def mock_req(method: str, uri: str, data=None, headers=None): + (expected_uri, resp_obj) = expected_requests.pop(0) + assert uri == expected_uri + resp = FakeResponse( + code=200, phrase=b"OK", body=json.dumps(resp_obj).encode("utf-8"), + ) + return resp + + with patch.object(self.hs.get_proxied_http_client(), "request", mock_req): + # now hit the callback URI with the right params and a made-up code + _, channel = make_request( + self.hs.get_reactor(), + self.site, + "GET", + redirect_uri, + custom_headers=[ + ("Cookie", "%s=%s" % (k, v)) for (k, v) in cookies.items() + ], + ) + + # expect a confirmation page + assert channel.code == 200 + + # fish the matrix login token out of the body of the confirmation page + m = re.search( + 'a href="%s.*loginToken=([^"]*)"' % (client_redirect_url,), + channel.result["body"].decode("utf-8"), + ) + assert m + login_token = m.group(1) + + # finally, submit the matrix login token to the login API, which gives us our + # matrix access token and device id. + _, channel = make_request( + self.hs.get_reactor(), + self.site, + "POST", + "/login", + content={"type": "m.login.token", "token": login_token}, + ) + assert channel.code == 200 + return channel.json_body + + +# an 'oidc_config' suitable for login_with_oidc. +TEST_OIDC_CONFIG = { + "enabled": True, + "discover": False, + "issuer": "https://issuer.test", + "client_id": "test-client-id", + "client_secret": "test-client-secret", + "scopes": ["profile"], + "authorization_endpoint": "https://z", + "token_endpoint": "https://issuer.test/token", + "userinfo_endpoint": "https://issuer.test/userinfo", + "user_mapping_provider": {"config": {"localpart_template": "{{ user.sub }}"}}, +} diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/v2_alpha/test_auth.py index 77246e478f..ac67a9de29 100644 --- a/tests/rest/client/v2_alpha/test_auth.py +++ b/tests/rest/client/v2_alpha/test_auth.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + from typing import List, Union from twisted.internet.defer import succeed @@ -22,9 +23,11 @@ from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker from synapse.http.site import SynapseRequest from synapse.rest.client.v1 import login from synapse.rest.client.v2_alpha import auth, devices, register -from synapse.types import JsonDict +from synapse.rest.oidc import OIDCResource +from synapse.types import JsonDict, UserID from tests import unittest +from tests.rest.client.v1.utils import TEST_OIDC_CONFIG from tests.server import FakeChannel @@ -156,27 +159,45 @@ class UIAuthTests(unittest.HomeserverTestCase): register.register_servlets, ] + def default_config(self): + config = super().default_config() + + # we enable OIDC as a way of testing SSO flows + oidc_config = {} + oidc_config.update(TEST_OIDC_CONFIG) + oidc_config["allow_existing_users"] = True + + config["oidc_config"] = oidc_config + config["public_baseurl"] = "https://synapse.test" + return config + + def create_resource_dict(self): + resource_dict = super().create_resource_dict() + # mount the OIDC resource at /_synapse/oidc + resource_dict["/_synapse/oidc"] = OIDCResource(self.hs) + return resource_dict + def prepare(self, reactor, clock, hs): self.user_pass = "pass" self.user = self.register_user("test", self.user_pass) self.user_tok = self.login("test", self.user_pass) - def get_device_ids(self) -> List[str]: + def get_device_ids(self, access_token: str) -> List[str]: # Get the list of devices so one can be deleted. - request, channel = self.make_request( - "GET", "devices", access_token=self.user_tok, - ) # type: SynapseRequest, FakeChannel - - # Get the ID of the device. - self.assertEqual(request.code, 200) + _, channel = self.make_request("GET", "devices", access_token=access_token,) + self.assertEqual(channel.code, 200) return [d["device_id"] for d in channel.json_body["devices"]] def delete_device( - self, device: str, expected_response: int, body: Union[bytes, JsonDict] = b"" + self, + access_token: str, + device: str, + expected_response: int, + body: Union[bytes, JsonDict] = b"", ) -> FakeChannel: """Delete an individual device.""" request, channel = self.make_request( - "DELETE", "devices/" + device, body, access_token=self.user_tok + "DELETE", "devices/" + device, body, access_token=access_token, ) # type: SynapseRequest, FakeChannel # Ensure the response is sane. @@ -201,11 +222,11 @@ class UIAuthTests(unittest.HomeserverTestCase): """ Test user interactive authentication outside of registration. """ - device_id = self.get_device_ids()[0] + device_id = self.get_device_ids(self.user_tok)[0] # Attempt to delete this device. # Returns a 401 as per the spec - channel = self.delete_device(device_id, 401) + channel = self.delete_device(self.user_tok, device_id, 401) # Grab the session session = channel.json_body["session"] @@ -214,6 +235,7 @@ class UIAuthTests(unittest.HomeserverTestCase): # Make another request providing the UI auth flow. self.delete_device( + self.user_tok, device_id, 200, { @@ -233,12 +255,13 @@ class UIAuthTests(unittest.HomeserverTestCase): UIA - check that still works. """ - device_id = self.get_device_ids()[0] - channel = self.delete_device(device_id, 401) + device_id = self.get_device_ids(self.user_tok)[0] + channel = self.delete_device(self.user_tok, device_id, 401) session = channel.json_body["session"] # Make another request providing the UI auth flow. self.delete_device( + self.user_tok, device_id, 200, { @@ -264,7 +287,7 @@ class UIAuthTests(unittest.HomeserverTestCase): # Create a second login. self.login("test", self.user_pass) - device_ids = self.get_device_ids() + device_ids = self.get_device_ids(self.user_tok) self.assertEqual(len(device_ids), 2) # Attempt to delete the first device. @@ -298,12 +321,12 @@ class UIAuthTests(unittest.HomeserverTestCase): # Create a second login. self.login("test", self.user_pass) - device_ids = self.get_device_ids() + device_ids = self.get_device_ids(self.user_tok) self.assertEqual(len(device_ids), 2) # Attempt to delete the first device. # Returns a 401 as per the spec - channel = self.delete_device(device_ids[0], 401) + channel = self.delete_device(self.user_tok, device_ids[0], 401) # Grab the session session = channel.json_body["session"] @@ -313,6 +336,7 @@ class UIAuthTests(unittest.HomeserverTestCase): # Make another request providing the UI auth flow, but try to delete the # second device. This results in an error. self.delete_device( + self.user_tok, device_ids[1], 403, { @@ -324,3 +348,39 @@ class UIAuthTests(unittest.HomeserverTestCase): }, }, ) + + def test_does_not_offer_password_for_sso_user(self): + login_resp = self.helper.login_via_oidc("username") + user_tok = login_resp["access_token"] + device_id = login_resp["device_id"] + + # now call the device deletion API: we should get the option to auth with SSO + # and not password. + channel = self.delete_device(user_tok, device_id, 401) + + flows = channel.json_body["flows"] + self.assertEqual(flows, [{"stages": ["m.login.sso"]}]) + + def test_does_not_offer_sso_for_password_user(self): + # now call the device deletion API: we should get the option to auth with SSO + # and not password. + device_ids = self.get_device_ids(self.user_tok) + channel = self.delete_device(self.user_tok, device_ids[0], 401) + + flows = channel.json_body["flows"] + self.assertEqual(flows, [{"stages": ["m.login.password"]}]) + + def test_offers_both_flows_for_upgraded_user(self): + """A user that had a password and then logged in with SSO should get both flows + """ + login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart) + self.assertEqual(login_resp["user_id"], self.user) + + device_ids = self.get_device_ids(self.user_tok) + channel = self.delete_device(self.user_tok, device_ids[0], 401) + + flows = channel.json_body["flows"] + # we have no particular expectations of ordering here + self.assertIn({"stages": ["m.login.password"]}, flows) + self.assertIn({"stages": ["m.login.sso"]}, flows) + self.assertEqual(len(flows), 2) diff --git a/tests/server.py b/tests/server.py index eee970c43c..4faf32e335 100644 --- a/tests/server.py +++ b/tests/server.py @@ -259,6 +259,7 @@ def make_request( for k, v in custom_headers: req.requestHeaders.addRawHeader(k, v) + req.parseCookies() req.requestReceived(method, path, b"1.1") if await_result: |