summary refs log tree commit diff
path: root/tests/rest/client/test_auth.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/rest/client/test_auth.py')
-rw-r--r--tests/rest/client/test_auth.py422
1 files changed, 389 insertions, 33 deletions
diff --git a/tests/rest/client/test_auth.py b/tests/rest/client/test_auth.py
index 090cef5216..847294dc8e 100644
--- a/tests/rest/client/test_auth.py
+++ b/tests/rest/client/test_auth.py
@@ -12,6 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import re
 from http import HTTPStatus
 from typing import Any, Dict, List, Optional, Tuple, Union
 
@@ -21,7 +22,7 @@ from twisted.web.resource import Resource
 
 import synapse.rest.admin
 from synapse.api.constants import ApprovalNoticeMedium, LoginType
-from synapse.api.errors import Codes
+from synapse.api.errors import Codes, SynapseError
 from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker
 from synapse.rest.client import account, auth, devices, login, logout, register
 from synapse.rest.synapse.client import build_synapse_client_resource_tree
@@ -32,8 +33,8 @@ from synapse.util import Clock
 
 from tests import unittest
 from tests.handlers.test_oidc import HAS_OIDC
-from tests.rest.client.utils import TEST_OIDC_CONFIG
-from tests.server import FakeChannel
+from tests.rest.client.utils import TEST_OIDC_CONFIG, TEST_OIDC_ISSUER
+from tests.server import FakeChannel, make_request
 from tests.unittest import override_config, skip_unless
 
 
@@ -465,9 +466,11 @@ class UIAuthTests(unittest.HomeserverTestCase):
           * checking that the original operation succeeds
         """
 
+        fake_oidc_server = self.helper.fake_oidc_server()
+
         # log the user in
         remote_user_id = UserID.from_string(self.user).localpart
-        login_resp = self.helper.login_via_oidc(remote_user_id)
+        login_resp, _ = self.helper.login_via_oidc(fake_oidc_server, remote_user_id)
         self.assertEqual(login_resp["user_id"], self.user)
 
         # initiate a UI Auth process by attempting to delete the device
@@ -481,8 +484,8 @@ class UIAuthTests(unittest.HomeserverTestCase):
 
         # run the UIA-via-SSO flow
         session_id = channel.json_body["session"]
-        channel = self.helper.auth_via_oidc(
-            {"sub": remote_user_id}, ui_auth_session_id=session_id
+        channel, _ = self.helper.auth_via_oidc(
+            fake_oidc_server, {"sub": remote_user_id}, ui_auth_session_id=session_id
         )
 
         # that should serve a confirmation page
@@ -499,7 +502,8 @@ class UIAuthTests(unittest.HomeserverTestCase):
     @skip_unless(HAS_OIDC, "requires OIDC")
     @override_config({"oidc_config": TEST_OIDC_CONFIG})
     def test_does_not_offer_password_for_sso_user(self) -> None:
-        login_resp = self.helper.login_via_oidc("username")
+        fake_oidc_server = self.helper.fake_oidc_server()
+        login_resp, _ = self.helper.login_via_oidc(fake_oidc_server, "username")
         user_tok = login_resp["access_token"]
         device_id = login_resp["device_id"]
 
@@ -522,7 +526,10 @@ class UIAuthTests(unittest.HomeserverTestCase):
     @override_config({"oidc_config": TEST_OIDC_CONFIG})
     def test_offers_both_flows_for_upgraded_user(self) -> None:
         """A user that had a password and then logged in with SSO should get both flows"""
-        login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart)
+        fake_oidc_server = self.helper.fake_oidc_server()
+        login_resp, _ = self.helper.login_via_oidc(
+            fake_oidc_server, UserID.from_string(self.user).localpart
+        )
         self.assertEqual(login_resp["user_id"], self.user)
 
         channel = self.delete_device(
@@ -539,8 +546,13 @@ class UIAuthTests(unittest.HomeserverTestCase):
     @override_config({"oidc_config": TEST_OIDC_CONFIG})
     def test_ui_auth_fails_for_incorrect_sso_user(self) -> None:
         """If the user tries to authenticate with the wrong SSO user, they get an error"""
+
+        fake_oidc_server = self.helper.fake_oidc_server()
+
         # log the user in
-        login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart)
+        login_resp, _ = self.helper.login_via_oidc(
+            fake_oidc_server, UserID.from_string(self.user).localpart
+        )
         self.assertEqual(login_resp["user_id"], self.user)
 
         # start a UI Auth flow by attempting to delete a device
@@ -553,8 +565,8 @@ class UIAuthTests(unittest.HomeserverTestCase):
         session_id = channel.json_body["session"]
 
         # do the OIDC auth, but auth as the wrong user
-        channel = self.helper.auth_via_oidc(
-            {"sub": "wrong_user"}, ui_auth_session_id=session_id
+        channel, _ = self.helper.auth_via_oidc(
+            fake_oidc_server, {"sub": "wrong_user"}, ui_auth_session_id=session_id
         )
 
         # that should return a failure message
@@ -584,7 +596,10 @@ class UIAuthTests(unittest.HomeserverTestCase):
         """Tests that if we register a user via SSO while requiring approval for new
         accounts, we still raise the correct error before logging the user in.
         """
-        login_resp = self.helper.login_via_oidc("username", expected_status=403)
+        fake_oidc_server = self.helper.fake_oidc_server()
+        login_resp, _ = self.helper.login_via_oidc(
+            fake_oidc_server, "username", expected_status=403
+        )
 
         self.assertEqual(login_resp["errcode"], Codes.USER_AWAITING_APPROVAL)
         self.assertEqual(
@@ -624,19 +639,6 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
             {"refresh_token": refresh_token},
         )
 
-    def is_access_token_valid(self, access_token: str) -> bool:
-        """
-        Checks whether an access token is valid, returning whether it is or not.
-        """
-        code = self.make_request(
-            "GET", "/_matrix/client/v3/account/whoami", access_token=access_token
-        ).code
-
-        # Either 200 or 401 is what we get back; anything else is a bug.
-        assert code in {HTTPStatus.OK, HTTPStatus.UNAUTHORIZED}
-
-        return code == HTTPStatus.OK
-
     def test_login_issue_refresh_token(self) -> None:
         """
         A login response should include a refresh_token only if asked.
@@ -833,29 +835,37 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
         self.reactor.advance(59.0)
 
         # Both tokens should still be valid.
-        self.assertTrue(self.is_access_token_valid(refreshable_access_token))
-        self.assertTrue(self.is_access_token_valid(nonrefreshable_access_token))
+        self.helper.whoami(refreshable_access_token, expect_code=HTTPStatus.OK)
+        self.helper.whoami(nonrefreshable_access_token, expect_code=HTTPStatus.OK)
 
         # Advance to 61 s (just past 1 minute, the time of expiry)
         self.reactor.advance(2.0)
 
         # Only the non-refreshable token is still valid.
-        self.assertFalse(self.is_access_token_valid(refreshable_access_token))
-        self.assertTrue(self.is_access_token_valid(nonrefreshable_access_token))
+        self.helper.whoami(
+            refreshable_access_token, expect_code=HTTPStatus.UNAUTHORIZED
+        )
+        self.helper.whoami(nonrefreshable_access_token, expect_code=HTTPStatus.OK)
 
         # Advance to 599 s (just shy of 10 minutes, the time of expiry)
         self.reactor.advance(599.0 - 61.0)
 
         # It's still the case that only the non-refreshable token is still valid.
-        self.assertFalse(self.is_access_token_valid(refreshable_access_token))
-        self.assertTrue(self.is_access_token_valid(nonrefreshable_access_token))
+        self.helper.whoami(
+            refreshable_access_token, expect_code=HTTPStatus.UNAUTHORIZED
+        )
+        self.helper.whoami(nonrefreshable_access_token, expect_code=HTTPStatus.OK)
 
         # Advance to 601 s (just past 10 minutes, the time of expiry)
         self.reactor.advance(2.0)
 
         # Now neither token is valid.
-        self.assertFalse(self.is_access_token_valid(refreshable_access_token))
-        self.assertFalse(self.is_access_token_valid(nonrefreshable_access_token))
+        self.helper.whoami(
+            refreshable_access_token, expect_code=HTTPStatus.UNAUTHORIZED
+        )
+        self.helper.whoami(
+            nonrefreshable_access_token, expect_code=HTTPStatus.UNAUTHORIZED
+        )
 
     @override_config(
         {"refreshable_access_token_lifetime": "1m", "refresh_token_lifetime": "2m"}
@@ -1151,3 +1161,349 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
         # and no refresh token
         self.assertEqual(_table_length("access_tokens"), 0)
         self.assertEqual(_table_length("refresh_tokens"), 0)
+
+
+def oidc_config(
+    id: str, with_localpart_template: bool, **kwargs: Any
+) -> Dict[str, Any]:
+    """Sample OIDC provider config used in backchannel logout tests.
+
+    Args:
+        id: IDP ID for this provider
+        with_localpart_template: Set to `true` to have a default localpart_template in
+            the `user_mapping_provider` config and skip the user mapping session
+        **kwargs: rest of the config
+
+    Returns:
+        A dict suitable for the `oidc_config` or the `oidc_providers[]` parts of
+        the HS config
+    """
+    config: Dict[str, Any] = {
+        "idp_id": id,
+        "idp_name": id,
+        "issuer": TEST_OIDC_ISSUER,
+        "client_id": "test-client-id",
+        "client_secret": "test-client-secret",
+        "scopes": ["openid"],
+    }
+
+    if with_localpart_template:
+        config["user_mapping_provider"] = {
+            "config": {"localpart_template": "{{ user.sub }}"}
+        }
+    else:
+        config["user_mapping_provider"] = {"config": {}}
+
+    config.update(kwargs)
+
+    return config
+
+
+@skip_unless(HAS_OIDC, "Requires OIDC")
+class OidcBackchannelLogoutTests(unittest.HomeserverTestCase):
+    servlets = [
+        account.register_servlets,
+        login.register_servlets,
+    ]
+
+    def default_config(self) -> Dict[str, Any]:
+        config = super().default_config()
+
+        # public_baseurl uses an http:// scheme because FakeChannel.isSecure() returns
+        # False, so synapse will see the requested uri as http://..., so using http in
+        # the public_baseurl stops Synapse trying to redirect to https.
+        config["public_baseurl"] = "http://synapse.test"
+
+        return config
+
+    def create_resource_dict(self) -> Dict[str, Resource]:
+        resource_dict = super().create_resource_dict()
+        resource_dict.update(build_synapse_client_resource_tree(self.hs))
+        return resource_dict
+
+    def submit_logout_token(self, logout_token: str) -> FakeChannel:
+        return self.make_request(
+            "POST",
+            "/_synapse/client/oidc/backchannel_logout",
+            content=f"logout_token={logout_token}",
+            content_is_form=True,
+        )
+
+    @override_config(
+        {
+            "oidc_providers": [
+                oidc_config(
+                    id="oidc",
+                    with_localpart_template=True,
+                    backchannel_logout_enabled=True,
+                )
+            ]
+        }
+    )
+    def test_simple_logout(self) -> None:
+        """
+        Receiving a logout token should logout the user
+        """
+        fake_oidc_server = self.helper.fake_oidc_server()
+        user = "john"
+
+        login_resp, first_grant = self.helper.login_via_oidc(
+            fake_oidc_server, user, with_sid=True
+        )
+        first_access_token: str = login_resp["access_token"]
+        self.helper.whoami(first_access_token, expect_code=HTTPStatus.OK)
+
+        login_resp, second_grant = self.helper.login_via_oidc(
+            fake_oidc_server, user, with_sid=True
+        )
+        second_access_token: str = login_resp["access_token"]
+        self.helper.whoami(second_access_token, expect_code=HTTPStatus.OK)
+
+        self.assertNotEqual(first_grant.sid, second_grant.sid)
+        self.assertEqual(first_grant.userinfo["sub"], second_grant.userinfo["sub"])
+
+        # Logging out of the first session
+        logout_token = fake_oidc_server.generate_logout_token(first_grant)
+        channel = self.submit_logout_token(logout_token)
+        self.assertEqual(channel.code, 200)
+
+        self.helper.whoami(first_access_token, expect_code=HTTPStatus.UNAUTHORIZED)
+        self.helper.whoami(second_access_token, expect_code=HTTPStatus.OK)
+
+        # Logging out of the second session
+        logout_token = fake_oidc_server.generate_logout_token(second_grant)
+        channel = self.submit_logout_token(logout_token)
+        self.assertEqual(channel.code, 200)
+
+    @override_config(
+        {
+            "oidc_providers": [
+                oidc_config(
+                    id="oidc",
+                    with_localpart_template=True,
+                    backchannel_logout_enabled=True,
+                )
+            ]
+        }
+    )
+    def test_logout_during_login(self) -> None:
+        """
+        It should revoke login tokens when receiving a logout token
+        """
+        fake_oidc_server = self.helper.fake_oidc_server()
+        user = "john"
+
+        # Get an authentication, and logout before submitting the logout token
+        client_redirect_url = "https://x"
+        userinfo = {"sub": user}
+        channel, grant = self.helper.auth_via_oidc(
+            fake_oidc_server,
+            userinfo,
+            client_redirect_url,
+            with_sid=True,
+        )
+
+        # expect a confirmation page
+        self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
+
+        # fish the matrix login token out of the body of the confirmation page
+        m = re.search(
+            'a href="%s.*loginToken=([^"]*)"' % (client_redirect_url,),
+            channel.text_body,
+        )
+        assert m, channel.text_body
+        login_token = m.group(1)
+
+        # Submit a logout
+        logout_token = fake_oidc_server.generate_logout_token(grant)
+        channel = self.submit_logout_token(logout_token)
+        self.assertEqual(channel.code, 200)
+
+        # Now try to exchange the login token
+        channel = make_request(
+            self.hs.get_reactor(),
+            self.site,
+            "POST",
+            "/login",
+            content={"type": "m.login.token", "token": login_token},
+        )
+        # It should have failed
+        self.assertEqual(channel.code, 403)
+
+    @override_config(
+        {
+            "oidc_providers": [
+                oidc_config(
+                    id="oidc",
+                    with_localpart_template=False,
+                    backchannel_logout_enabled=True,
+                )
+            ]
+        }
+    )
+    def test_logout_during_mapping(self) -> None:
+        """
+        It should stop ongoing user mapping session when receiving a logout token
+        """
+        fake_oidc_server = self.helper.fake_oidc_server()
+        user = "john"
+
+        # Get an authentication, and logout before submitting the logout token
+        client_redirect_url = "https://x"
+        userinfo = {"sub": user}
+        channel, grant = self.helper.auth_via_oidc(
+            fake_oidc_server,
+            userinfo,
+            client_redirect_url,
+            with_sid=True,
+        )
+
+        # Expect a user mapping page
+        self.assertEqual(channel.code, HTTPStatus.FOUND, channel.result)
+
+        # We should have a user_mapping_session cookie
+        cookie_headers = channel.headers.getRawHeaders("Set-Cookie")
+        assert cookie_headers
+        cookies: Dict[str, str] = {}
+        for h in cookie_headers:
+            key, value = h.split(";")[0].split("=", maxsplit=1)
+            cookies[key] = value
+
+        user_mapping_session_id = cookies["username_mapping_session"]
+
+        # Getting that session should not raise
+        session = self.hs.get_sso_handler().get_mapping_session(user_mapping_session_id)
+        self.assertIsNotNone(session)
+
+        # Submit a logout
+        logout_token = fake_oidc_server.generate_logout_token(grant)
+        channel = self.submit_logout_token(logout_token)
+        self.assertEqual(channel.code, 200)
+
+        # Now it should raise
+        with self.assertRaises(SynapseError):
+            self.hs.get_sso_handler().get_mapping_session(user_mapping_session_id)
+
+    @override_config(
+        {
+            "oidc_providers": [
+                oidc_config(
+                    id="oidc",
+                    with_localpart_template=True,
+                    backchannel_logout_enabled=False,
+                )
+            ]
+        }
+    )
+    def test_disabled(self) -> None:
+        """
+        Receiving a logout token should do nothing if it is disabled in the config
+        """
+        fake_oidc_server = self.helper.fake_oidc_server()
+        user = "john"
+
+        login_resp, grant = self.helper.login_via_oidc(
+            fake_oidc_server, user, with_sid=True
+        )
+        access_token: str = login_resp["access_token"]
+        self.helper.whoami(access_token, expect_code=HTTPStatus.OK)
+
+        # Logging out shouldn't work
+        logout_token = fake_oidc_server.generate_logout_token(grant)
+        channel = self.submit_logout_token(logout_token)
+        self.assertEqual(channel.code, 400)
+
+        # And the token should still be valid
+        self.helper.whoami(access_token, expect_code=HTTPStatus.OK)
+
+    @override_config(
+        {
+            "oidc_providers": [
+                oidc_config(
+                    id="oidc",
+                    with_localpart_template=True,
+                    backchannel_logout_enabled=True,
+                )
+            ]
+        }
+    )
+    def test_no_sid(self) -> None:
+        """
+        Receiving a logout token without `sid` during the login should do nothing
+        """
+        fake_oidc_server = self.helper.fake_oidc_server()
+        user = "john"
+
+        login_resp, grant = self.helper.login_via_oidc(
+            fake_oidc_server, user, with_sid=False
+        )
+        access_token: str = login_resp["access_token"]
+        self.helper.whoami(access_token, expect_code=HTTPStatus.OK)
+
+        # Logging out shouldn't work
+        logout_token = fake_oidc_server.generate_logout_token(grant)
+        channel = self.submit_logout_token(logout_token)
+        self.assertEqual(channel.code, 400)
+
+        # And the token should still be valid
+        self.helper.whoami(access_token, expect_code=HTTPStatus.OK)
+
+    @override_config(
+        {
+            "oidc_providers": [
+                oidc_config(
+                    "first",
+                    issuer="https://first-issuer.com/",
+                    with_localpart_template=True,
+                    backchannel_logout_enabled=True,
+                ),
+                oidc_config(
+                    "second",
+                    issuer="https://second-issuer.com/",
+                    with_localpart_template=True,
+                    backchannel_logout_enabled=True,
+                ),
+            ]
+        }
+    )
+    def test_multiple_providers(self) -> None:
+        """
+        It should be able to distinguish login tokens from two different IdPs
+        """
+        first_server = self.helper.fake_oidc_server(issuer="https://first-issuer.com/")
+        second_server = self.helper.fake_oidc_server(
+            issuer="https://second-issuer.com/"
+        )
+        user = "john"
+
+        login_resp, first_grant = self.helper.login_via_oidc(
+            first_server, user, with_sid=True, idp_id="oidc-first"
+        )
+        first_access_token: str = login_resp["access_token"]
+        self.helper.whoami(first_access_token, expect_code=HTTPStatus.OK)
+
+        login_resp, second_grant = self.helper.login_via_oidc(
+            second_server, user, with_sid=True, idp_id="oidc-second"
+        )
+        second_access_token: str = login_resp["access_token"]
+        self.helper.whoami(second_access_token, expect_code=HTTPStatus.OK)
+
+        # `sid` in the fake providers are generated by a counter, so the first grant of
+        # each provider should give the same SID
+        self.assertEqual(first_grant.sid, second_grant.sid)
+        self.assertEqual(first_grant.userinfo["sub"], second_grant.userinfo["sub"])
+
+        # Logging out of the first session
+        logout_token = first_server.generate_logout_token(first_grant)
+        channel = self.submit_logout_token(logout_token)
+        self.assertEqual(channel.code, 200)
+
+        self.helper.whoami(first_access_token, expect_code=HTTPStatus.UNAUTHORIZED)
+        self.helper.whoami(second_access_token, expect_code=HTTPStatus.OK)
+
+        # Logging out of the second session
+        logout_token = second_server.generate_logout_token(second_grant)
+        channel = self.submit_logout_token(logout_token)
+        self.assertEqual(channel.code, 200)
+
+        self.helper.whoami(second_access_token, expect_code=HTTPStatus.UNAUTHORIZED)