summary refs log tree commit diff
diff options
context:
space:
mode:
authorRichard van der Hoff <richard@matrix.org>2020-12-01 00:15:36 +0000
committerRichard van der Hoff <richard@matrix.org>2020-12-02 18:54:15 +0000
commit0bac276890567ef3a3fafd7f5b7b5cac91a1031b (patch)
tree676f10e3ed3c786906f01f7ab5d8d2a1d752ea3c
parentFactor out FakeResponse from test_oidc (diff)
downloadsynapse-0bac276890567ef3a3fafd7f5b7b5cac91a1031b.tar.xz
UIA: offer only available auth flows
During user-interactive auth, do not offer password auth to users with no
password, nor SSO auth to users with no SSO.

Fixes #7559.
Diffstat (limited to '')
-rw-r--r--synapse/handlers/auth.py58
-rw-r--r--synapse/storage/databases/main/registration.py25
-rw-r--r--synapse/storage/databases/main/schema/delta/58/25user_external_ids_user_id_idx.sql17
-rw-r--r--tests/rest/client/v1/utils.py116
-rw-r--r--tests/rest/client/v2_alpha/test_auth.py94
-rw-r--r--tests/server.py1
6 files changed, 278 insertions, 33 deletions
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index c7dc07008a..2e72298e05 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -193,9 +193,7 @@ class AuthHandler(BaseHandler):
         self.hs = hs  # FIXME better possibility to access registrationHandler later?
         self.macaroon_gen = hs.get_macaroon_generator()
         self._password_enabled = hs.config.password_enabled
-        self._sso_enabled = (
-            hs.config.cas_enabled or hs.config.saml2_enabled or hs.config.oidc_enabled
-        )
+        self._password_localdb_enabled = hs.config.password_localdb_enabled
 
         # we keep this as a list despite the O(N^2) implication so that we can
         # keep PASSWORD first and avoid confusing clients which pick the first
@@ -205,7 +203,7 @@ class AuthHandler(BaseHandler):
 
         # start out by assuming PASSWORD is enabled; we will remove it later if not.
         login_types = []
-        if hs.config.password_localdb_enabled:
+        if self._password_localdb_enabled:
             login_types.append(LoginType.PASSWORD)
 
         for provider in self.password_providers:
@@ -219,14 +217,6 @@ class AuthHandler(BaseHandler):
 
         self._supported_login_types = login_types
 
-        # Login types and UI Auth types have a heavy overlap, but are not
-        # necessarily identical. Login types have SSO (and other login types)
-        # added in the rest layer, see synapse.rest.client.v1.login.LoginRestServerlet.on_GET.
-        ui_auth_types = login_types.copy()
-        if self._sso_enabled:
-            ui_auth_types.append(LoginType.SSO)
-        self._supported_ui_auth_types = ui_auth_types
-
         # Ratelimiter for failed auth during UIA. Uses same ratelimit config
         # as per `rc_login.failed_attempts`.
         self._failed_uia_attempts_ratelimiter = Ratelimiter(
@@ -339,7 +329,10 @@ class AuthHandler(BaseHandler):
         self._failed_uia_attempts_ratelimiter.ratelimit(user_id, update=False)
 
         # build a list of supported flows
-        flows = [[login_type] for login_type in self._supported_ui_auth_types]
+        supported_ui_auth_types = await self._get_available_ui_auth_types(
+            requester.user
+        )
+        flows = [[login_type] for login_type in supported_ui_auth_types]
 
         try:
             result, params, session_id = await self.check_ui_auth(
@@ -351,7 +344,7 @@ class AuthHandler(BaseHandler):
             raise
 
         # find the completed login type
-        for login_type in self._supported_ui_auth_types:
+        for login_type in supported_ui_auth_types:
             if login_type not in result:
                 continue
 
@@ -367,6 +360,41 @@ class AuthHandler(BaseHandler):
 
         return params, session_id
 
+    async def _get_available_ui_auth_types(self, user: UserID) -> Iterable[str]:
+        """Get a list of the authentication types this user can use
+        """
+
+        ui_auth_types = set()
+
+        # if the HS supports password auth, and the user has a non-null password, we
+        # support password auth
+        if self._password_localdb_enabled and self._password_enabled:
+            lookupres = await self._find_user_id_and_pwd_hash(user.to_string())
+            if lookupres:
+                _, password_hash = lookupres
+                if password_hash:
+                    ui_auth_types.add(LoginType.PASSWORD)
+
+        # also allow auth from password providers
+        for provider in self.password_providers:
+            for t in provider.get_supported_login_types().keys():
+                if t == LoginType.PASSWORD and not self._password_enabled:
+                    continue
+                ui_auth_types.add(t)
+
+        # if sso is enabled, allow the user to log in via SSO iff they have a mapping
+        # from sso to mxid.
+        if self.hs.config.saml2.saml2_enabled or self.hs.config.oidc.oidc_enabled:
+            if await self.store.get_external_ids_by_user(user.to_string()):
+                ui_auth_types.add(LoginType.SSO)
+
+        # Our CAS impl does not (yet) correctly register users in user_external_ids,
+        # so always offer that if it's available.
+        if self.hs.config.cas.cas_enabled:
+            ui_auth_types.add(LoginType.SSO)
+
+        return ui_auth_types
+
     def get_enabled_auth_types(self):
         """Return the enabled user-interactive authentication types
 
@@ -1029,7 +1057,7 @@ class AuthHandler(BaseHandler):
             if result:
                 return result
 
-        if login_type == LoginType.PASSWORD and self.hs.config.password_localdb_enabled:
+        if login_type == LoginType.PASSWORD and self._password_localdb_enabled:
             known_login_type = True
 
             # we've already checked that there is a (valid) password field
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index fedb8a6c26..ff96c34c2e 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -463,6 +463,23 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
             desc="get_user_by_external_id",
         )
 
+    async def get_external_ids_by_user(self, mxid: str) -> List[Tuple[str, str]]:
+        """Look up external ids for the given user
+
+        Args:
+            mxid: the MXID to be looked up
+
+        Returns:
+            Tuples of (auth_provider, external_id)
+        """
+        res = await self.db_pool.simple_select_list(
+            table="user_external_ids",
+            keyvalues={"user_id": mxid},
+            retcols=("auth_provider", "external_id"),
+            desc="get_external_ids_by_user",
+        )
+        return [(r["auth_provider"], r["external_id"]) for r in res]
+
     async def count_all_users(self):
         """Counts all users registered on the homeserver."""
 
@@ -963,6 +980,14 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
             "users_set_deactivated_flag", self._background_update_set_deactivated_flag
         )
 
+        self.db_pool.updates.register_background_index_update(
+            "user_external_ids_user_id_idx",
+            index_name="user_external_ids_user_id_idx",
+            table="user_external_ids",
+            columns=["user_id"],
+            unique=False,
+        )
+
     async def _background_update_set_deactivated_flag(self, progress, batch_size):
         """Retrieves a list of all deactivated users and sets the 'deactivated' flag to 1
         for each of them.
diff --git a/synapse/storage/databases/main/schema/delta/58/25user_external_ids_user_id_idx.sql b/synapse/storage/databases/main/schema/delta/58/25user_external_ids_user_id_idx.sql
new file mode 100644
index 0000000000..8f5e65aa71
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/25user_external_ids_user_id_idx.sql
@@ -0,0 +1,17 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
+  (5825, 'user_external_ids_user_id_idx', '{}');
diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py
index 737c38c396..5a18af8d34 100644
--- a/tests/rest/client/v1/utils.py
+++ b/tests/rest/client/v1/utils.py
@@ -2,7 +2,7 @@
 # Copyright 2014-2016 OpenMarket Ltd
 # Copyright 2017 Vector Creations Ltd
 # Copyright 2018-2019 New Vector Ltd
-# Copyright 2019 The Matrix.org Foundation C.I.C.
+# Copyright 2019-2020 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -17,17 +17,23 @@
 # limitations under the License.
 
 import json
+import re
 import time
+import urllib.parse
 from typing import Any, Dict, Optional
 
+from mock import patch
+
 import attr
 
 from twisted.web.resource import Resource
 from twisted.web.server import Site
 
 from synapse.api.constants import Membership
+from synapse.types import JsonDict
 
 from tests.server import FakeSite, make_request
+from tests.test_utils import FakeResponse
 
 
 @attr.s
@@ -344,3 +350,111 @@ class RestHelper:
         )
 
         return channel.json_body
+
+    def login_via_oidc(self, remote_user_id: str) -> JsonDict:
+        """Log in (as a new user) via OIDC
+
+        Returns the result of the final token login.
+
+        Requires that "oidc_config" in the homeserver config be set appropriately
+        (TEST_OIDC_CONFIG is a suitable example) - and by implication, needs a
+        "public_base_url".
+
+        Also requires the login servlet and the OIDC callback resource to be mounted at
+        the normal places.
+        """
+        client_redirect_url = "https://x"
+
+        # first hit the redirect url (which will issue a cookie and state)
+        _, channel = make_request(
+            self.hs.get_reactor(),
+            self.site,
+            "GET",
+            "/login/sso/redirect?redirectUrl=" + client_redirect_url,
+        )
+        # that will redirect to the OIDC IdP, but we skip that and go straight
+        # back to synapse's OIDC callback resource. However, we do need the "state"
+        # param that synapse passes to the IdP via query params, and the cookie that
+        # synapse passes to the client.
+        assert channel.code == 302
+        oauth_uri = channel.headers.getRawHeaders("Location")[0]
+        params = urllib.parse.parse_qs(urllib.parse.urlparse(oauth_uri).query)
+        redirect_uri = "%s?%s" % (
+            urllib.parse.urlparse(params["redirect_uri"][0]).path,
+            urllib.parse.urlencode({"state": params["state"][0], "code": "TEST_CODE"}),
+        )
+        cookies = {}
+        for h in channel.headers.getRawHeaders("Set-Cookie"):
+            parts = h.split(";")
+            k, v = parts[0].split("=", maxsplit=1)
+            cookies[k] = v
+
+        # before we hit the callback uri, stub out some methods in the http client so
+        # that we don't have to handle full HTTPS requests.
+
+        # (expected url, json response) pairs, in the order we expect them.
+        expected_requests = [
+            # first we get a hit to the token endpoint, which we tell to return
+            # a dummy OIDC access token
+            ("https://issuer.test/token", {"access_token": "TEST"}),
+            # and then one to the user_info endpoint, which returns our remote user id.
+            ("https://issuer.test/userinfo", {"sub": remote_user_id}),
+        ]
+
+        async def mock_req(method: str, uri: str, data=None, headers=None):
+            (expected_uri, resp_obj) = expected_requests.pop(0)
+            assert uri == expected_uri
+            resp = FakeResponse(
+                code=200, phrase=b"OK", body=json.dumps(resp_obj).encode("utf-8"),
+            )
+            return resp
+
+        with patch.object(self.hs.get_proxied_http_client(), "request", mock_req):
+            # now hit the callback URI with the right params and a made-up code
+            _, channel = make_request(
+                self.hs.get_reactor(),
+                self.site,
+                "GET",
+                redirect_uri,
+                custom_headers=[
+                    ("Cookie", "%s=%s" % (k, v)) for (k, v) in cookies.items()
+                ],
+            )
+
+        # expect a confirmation page
+        assert channel.code == 200
+
+        # fish the matrix login token out of the body of the confirmation page
+        m = re.search(
+            'a href="%s.*loginToken=([^"]*)"' % (client_redirect_url,),
+            channel.result["body"].decode("utf-8"),
+        )
+        assert m
+        login_token = m.group(1)
+
+        # finally, submit the matrix login token to the login API, which gives us our
+        # matrix access token and device id.
+        _, channel = make_request(
+            self.hs.get_reactor(),
+            self.site,
+            "POST",
+            "/login",
+            content={"type": "m.login.token", "token": login_token},
+        )
+        assert channel.code == 200
+        return channel.json_body
+
+
+# an 'oidc_config' suitable for login_with_oidc.
+TEST_OIDC_CONFIG = {
+    "enabled": True,
+    "discover": False,
+    "issuer": "https://issuer.test",
+    "client_id": "test-client-id",
+    "client_secret": "test-client-secret",
+    "scopes": ["profile"],
+    "authorization_endpoint": "https://z",
+    "token_endpoint": "https://issuer.test/token",
+    "userinfo_endpoint": "https://issuer.test/userinfo",
+    "user_mapping_provider": {"config": {"localpart_template": "{{ user.sub }}"}},
+}
diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/v2_alpha/test_auth.py
index 77246e478f..ac67a9de29 100644
--- a/tests/rest/client/v2_alpha/test_auth.py
+++ b/tests/rest/client/v2_alpha/test_auth.py
@@ -12,6 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
 from typing import List, Union
 
 from twisted.internet.defer import succeed
@@ -22,9 +23,11 @@ from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker
 from synapse.http.site import SynapseRequest
 from synapse.rest.client.v1 import login
 from synapse.rest.client.v2_alpha import auth, devices, register
-from synapse.types import JsonDict
+from synapse.rest.oidc import OIDCResource
+from synapse.types import JsonDict, UserID
 
 from tests import unittest
+from tests.rest.client.v1.utils import TEST_OIDC_CONFIG
 from tests.server import FakeChannel
 
 
@@ -156,27 +159,45 @@ class UIAuthTests(unittest.HomeserverTestCase):
         register.register_servlets,
     ]
 
+    def default_config(self):
+        config = super().default_config()
+
+        # we enable OIDC as a way of testing SSO flows
+        oidc_config = {}
+        oidc_config.update(TEST_OIDC_CONFIG)
+        oidc_config["allow_existing_users"] = True
+
+        config["oidc_config"] = oidc_config
+        config["public_baseurl"] = "https://synapse.test"
+        return config
+
+    def create_resource_dict(self):
+        resource_dict = super().create_resource_dict()
+        # mount the OIDC resource at /_synapse/oidc
+        resource_dict["/_synapse/oidc"] = OIDCResource(self.hs)
+        return resource_dict
+
     def prepare(self, reactor, clock, hs):
         self.user_pass = "pass"
         self.user = self.register_user("test", self.user_pass)
         self.user_tok = self.login("test", self.user_pass)
 
-    def get_device_ids(self) -> List[str]:
+    def get_device_ids(self, access_token: str) -> List[str]:
         # Get the list of devices so one can be deleted.
-        request, channel = self.make_request(
-            "GET", "devices", access_token=self.user_tok,
-        )  # type: SynapseRequest, FakeChannel
-
-        # Get the ID of the device.
-        self.assertEqual(request.code, 200)
+        _, channel = self.make_request("GET", "devices", access_token=access_token,)
+        self.assertEqual(channel.code, 200)
         return [d["device_id"] for d in channel.json_body["devices"]]
 
     def delete_device(
-        self, device: str, expected_response: int, body: Union[bytes, JsonDict] = b""
+        self,
+        access_token: str,
+        device: str,
+        expected_response: int,
+        body: Union[bytes, JsonDict] = b"",
     ) -> FakeChannel:
         """Delete an individual device."""
         request, channel = self.make_request(
-            "DELETE", "devices/" + device, body, access_token=self.user_tok
+            "DELETE", "devices/" + device, body, access_token=access_token,
         )  # type: SynapseRequest, FakeChannel
 
         # Ensure the response is sane.
@@ -201,11 +222,11 @@ class UIAuthTests(unittest.HomeserverTestCase):
         """
         Test user interactive authentication outside of registration.
         """
-        device_id = self.get_device_ids()[0]
+        device_id = self.get_device_ids(self.user_tok)[0]
 
         # Attempt to delete this device.
         # Returns a 401 as per the spec
-        channel = self.delete_device(device_id, 401)
+        channel = self.delete_device(self.user_tok, device_id, 401)
 
         # Grab the session
         session = channel.json_body["session"]
@@ -214,6 +235,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
 
         # Make another request providing the UI auth flow.
         self.delete_device(
+            self.user_tok,
             device_id,
             200,
             {
@@ -233,12 +255,13 @@ class UIAuthTests(unittest.HomeserverTestCase):
         UIA - check that still works.
         """
 
-        device_id = self.get_device_ids()[0]
-        channel = self.delete_device(device_id, 401)
+        device_id = self.get_device_ids(self.user_tok)[0]
+        channel = self.delete_device(self.user_tok, device_id, 401)
         session = channel.json_body["session"]
 
         # Make another request providing the UI auth flow.
         self.delete_device(
+            self.user_tok,
             device_id,
             200,
             {
@@ -264,7 +287,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
         # Create a second login.
         self.login("test", self.user_pass)
 
-        device_ids = self.get_device_ids()
+        device_ids = self.get_device_ids(self.user_tok)
         self.assertEqual(len(device_ids), 2)
 
         # Attempt to delete the first device.
@@ -298,12 +321,12 @@ class UIAuthTests(unittest.HomeserverTestCase):
         # Create a second login.
         self.login("test", self.user_pass)
 
-        device_ids = self.get_device_ids()
+        device_ids = self.get_device_ids(self.user_tok)
         self.assertEqual(len(device_ids), 2)
 
         # Attempt to delete the first device.
         # Returns a 401 as per the spec
-        channel = self.delete_device(device_ids[0], 401)
+        channel = self.delete_device(self.user_tok, device_ids[0], 401)
 
         # Grab the session
         session = channel.json_body["session"]
@@ -313,6 +336,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
         # Make another request providing the UI auth flow, but try to delete the
         # second device. This results in an error.
         self.delete_device(
+            self.user_tok,
             device_ids[1],
             403,
             {
@@ -324,3 +348,39 @@ class UIAuthTests(unittest.HomeserverTestCase):
                 },
             },
         )
+
+    def test_does_not_offer_password_for_sso_user(self):
+        login_resp = self.helper.login_via_oidc("username")
+        user_tok = login_resp["access_token"]
+        device_id = login_resp["device_id"]
+
+        # now call the device deletion API: we should get the option to auth with SSO
+        # and not password.
+        channel = self.delete_device(user_tok, device_id, 401)
+
+        flows = channel.json_body["flows"]
+        self.assertEqual(flows, [{"stages": ["m.login.sso"]}])
+
+    def test_does_not_offer_sso_for_password_user(self):
+        # now call the device deletion API: we should get the option to auth with SSO
+        # and not password.
+        device_ids = self.get_device_ids(self.user_tok)
+        channel = self.delete_device(self.user_tok, device_ids[0], 401)
+
+        flows = channel.json_body["flows"]
+        self.assertEqual(flows, [{"stages": ["m.login.password"]}])
+
+    def test_offers_both_flows_for_upgraded_user(self):
+        """A user that had a password and then logged in with SSO should get both flows
+        """
+        login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart)
+        self.assertEqual(login_resp["user_id"], self.user)
+
+        device_ids = self.get_device_ids(self.user_tok)
+        channel = self.delete_device(self.user_tok, device_ids[0], 401)
+
+        flows = channel.json_body["flows"]
+        # we have no particular expectations of ordering here
+        self.assertIn({"stages": ["m.login.password"]}, flows)
+        self.assertIn({"stages": ["m.login.sso"]}, flows)
+        self.assertEqual(len(flows), 2)
diff --git a/tests/server.py b/tests/server.py
index eee970c43c..4faf32e335 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -259,6 +259,7 @@ def make_request(
         for k, v in custom_headers:
             req.requestHeaders.addRawHeader(k, v)
 
+    req.parseCookies()
     req.requestReceived(method, path, b"1.1")
 
     if await_result: