summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/rest/client/test_auth.py390
-rw-r--r--tests/rest/client/utils.py55
-rw-r--r--tests/server.py6
-rw-r--r--tests/test_utils/oidc.py27
4 files changed, 448 insertions, 30 deletions
diff --git a/tests/rest/client/test_auth.py b/tests/rest/client/test_auth.py
index ebf653d018..847294dc8e 100644
--- a/tests/rest/client/test_auth.py
+++ b/tests/rest/client/test_auth.py
@@ -12,6 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import re
 from http import HTTPStatus
 from typing import Any, Dict, List, Optional, Tuple, Union
 
@@ -21,7 +22,7 @@ from twisted.web.resource import Resource
 
 import synapse.rest.admin
 from synapse.api.constants import ApprovalNoticeMedium, LoginType
-from synapse.api.errors import Codes
+from synapse.api.errors import Codes, SynapseError
 from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker
 from synapse.rest.client import account, auth, devices, login, logout, register
 from synapse.rest.synapse.client import build_synapse_client_resource_tree
@@ -32,8 +33,8 @@ from synapse.util import Clock
 
 from tests import unittest
 from tests.handlers.test_oidc import HAS_OIDC
-from tests.rest.client.utils import TEST_OIDC_CONFIG
-from tests.server import FakeChannel
+from tests.rest.client.utils import TEST_OIDC_CONFIG, TEST_OIDC_ISSUER
+from tests.server import FakeChannel, make_request
 from tests.unittest import override_config, skip_unless
 
 
@@ -638,19 +639,6 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
             {"refresh_token": refresh_token},
         )
 
-    def is_access_token_valid(self, access_token: str) -> bool:
-        """
-        Checks whether an access token is valid, returning whether it is or not.
-        """
-        code = self.make_request(
-            "GET", "/_matrix/client/v3/account/whoami", access_token=access_token
-        ).code
-
-        # Either 200 or 401 is what we get back; anything else is a bug.
-        assert code in {HTTPStatus.OK, HTTPStatus.UNAUTHORIZED}
-
-        return code == HTTPStatus.OK
-
     def test_login_issue_refresh_token(self) -> None:
         """
         A login response should include a refresh_token only if asked.
@@ -847,29 +835,37 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
         self.reactor.advance(59.0)
 
         # Both tokens should still be valid.
-        self.assertTrue(self.is_access_token_valid(refreshable_access_token))
-        self.assertTrue(self.is_access_token_valid(nonrefreshable_access_token))
+        self.helper.whoami(refreshable_access_token, expect_code=HTTPStatus.OK)
+        self.helper.whoami(nonrefreshable_access_token, expect_code=HTTPStatus.OK)
 
         # Advance to 61 s (just past 1 minute, the time of expiry)
         self.reactor.advance(2.0)
 
         # Only the non-refreshable token is still valid.
-        self.assertFalse(self.is_access_token_valid(refreshable_access_token))
-        self.assertTrue(self.is_access_token_valid(nonrefreshable_access_token))
+        self.helper.whoami(
+            refreshable_access_token, expect_code=HTTPStatus.UNAUTHORIZED
+        )
+        self.helper.whoami(nonrefreshable_access_token, expect_code=HTTPStatus.OK)
 
         # Advance to 599 s (just shy of 10 minutes, the time of expiry)
         self.reactor.advance(599.0 - 61.0)
 
         # It's still the case that only the non-refreshable token is still valid.
-        self.assertFalse(self.is_access_token_valid(refreshable_access_token))
-        self.assertTrue(self.is_access_token_valid(nonrefreshable_access_token))
+        self.helper.whoami(
+            refreshable_access_token, expect_code=HTTPStatus.UNAUTHORIZED
+        )
+        self.helper.whoami(nonrefreshable_access_token, expect_code=HTTPStatus.OK)
 
         # Advance to 601 s (just past 10 minutes, the time of expiry)
         self.reactor.advance(2.0)
 
         # Now neither token is valid.
-        self.assertFalse(self.is_access_token_valid(refreshable_access_token))
-        self.assertFalse(self.is_access_token_valid(nonrefreshable_access_token))
+        self.helper.whoami(
+            refreshable_access_token, expect_code=HTTPStatus.UNAUTHORIZED
+        )
+        self.helper.whoami(
+            nonrefreshable_access_token, expect_code=HTTPStatus.UNAUTHORIZED
+        )
 
     @override_config(
         {"refreshable_access_token_lifetime": "1m", "refresh_token_lifetime": "2m"}
@@ -1165,3 +1161,349 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
         # and no refresh token
         self.assertEqual(_table_length("access_tokens"), 0)
         self.assertEqual(_table_length("refresh_tokens"), 0)
+
+
+def oidc_config(
+    id: str, with_localpart_template: bool, **kwargs: Any
+) -> Dict[str, Any]:
+    """Sample OIDC provider config used in backchannel logout tests.
+
+    Args:
+        id: IDP ID for this provider
+        with_localpart_template: Set to `true` to have a default localpart_template in
+            the `user_mapping_provider` config and skip the user mapping session
+        **kwargs: rest of the config
+
+    Returns:
+        A dict suitable for the `oidc_config` or the `oidc_providers[]` parts of
+        the HS config
+    """
+    config: Dict[str, Any] = {
+        "idp_id": id,
+        "idp_name": id,
+        "issuer": TEST_OIDC_ISSUER,
+        "client_id": "test-client-id",
+        "client_secret": "test-client-secret",
+        "scopes": ["openid"],
+    }
+
+    if with_localpart_template:
+        config["user_mapping_provider"] = {
+            "config": {"localpart_template": "{{ user.sub }}"}
+        }
+    else:
+        config["user_mapping_provider"] = {"config": {}}
+
+    config.update(kwargs)
+
+    return config
+
+
+@skip_unless(HAS_OIDC, "Requires OIDC")
+class OidcBackchannelLogoutTests(unittest.HomeserverTestCase):
+    servlets = [
+        account.register_servlets,
+        login.register_servlets,
+    ]
+
+    def default_config(self) -> Dict[str, Any]:
+        config = super().default_config()
+
+        # public_baseurl uses an http:// scheme because FakeChannel.isSecure() returns
+        # False, so synapse will see the requested uri as http://..., so using http in
+        # the public_baseurl stops Synapse trying to redirect to https.
+        config["public_baseurl"] = "http://synapse.test"
+
+        return config
+
+    def create_resource_dict(self) -> Dict[str, Resource]:
+        resource_dict = super().create_resource_dict()
+        resource_dict.update(build_synapse_client_resource_tree(self.hs))
+        return resource_dict
+
+    def submit_logout_token(self, logout_token: str) -> FakeChannel:
+        return self.make_request(
+            "POST",
+            "/_synapse/client/oidc/backchannel_logout",
+            content=f"logout_token={logout_token}",
+            content_is_form=True,
+        )
+
+    @override_config(
+        {
+            "oidc_providers": [
+                oidc_config(
+                    id="oidc",
+                    with_localpart_template=True,
+                    backchannel_logout_enabled=True,
+                )
+            ]
+        }
+    )
+    def test_simple_logout(self) -> None:
+        """
+        Receiving a logout token should logout the user
+        """
+        fake_oidc_server = self.helper.fake_oidc_server()
+        user = "john"
+
+        login_resp, first_grant = self.helper.login_via_oidc(
+            fake_oidc_server, user, with_sid=True
+        )
+        first_access_token: str = login_resp["access_token"]
+        self.helper.whoami(first_access_token, expect_code=HTTPStatus.OK)
+
+        login_resp, second_grant = self.helper.login_via_oidc(
+            fake_oidc_server, user, with_sid=True
+        )
+        second_access_token: str = login_resp["access_token"]
+        self.helper.whoami(second_access_token, expect_code=HTTPStatus.OK)
+
+        self.assertNotEqual(first_grant.sid, second_grant.sid)
+        self.assertEqual(first_grant.userinfo["sub"], second_grant.userinfo["sub"])
+
+        # Logging out of the first session
+        logout_token = fake_oidc_server.generate_logout_token(first_grant)
+        channel = self.submit_logout_token(logout_token)
+        self.assertEqual(channel.code, 200)
+
+        self.helper.whoami(first_access_token, expect_code=HTTPStatus.UNAUTHORIZED)
+        self.helper.whoami(second_access_token, expect_code=HTTPStatus.OK)
+
+        # Logging out of the second session
+        logout_token = fake_oidc_server.generate_logout_token(second_grant)
+        channel = self.submit_logout_token(logout_token)
+        self.assertEqual(channel.code, 200)
+
+    @override_config(
+        {
+            "oidc_providers": [
+                oidc_config(
+                    id="oidc",
+                    with_localpart_template=True,
+                    backchannel_logout_enabled=True,
+                )
+            ]
+        }
+    )
+    def test_logout_during_login(self) -> None:
+        """
+        It should revoke login tokens when receiving a logout token
+        """
+        fake_oidc_server = self.helper.fake_oidc_server()
+        user = "john"
+
+        # Get an authentication, and logout before submitting the logout token
+        client_redirect_url = "https://x"
+        userinfo = {"sub": user}
+        channel, grant = self.helper.auth_via_oidc(
+            fake_oidc_server,
+            userinfo,
+            client_redirect_url,
+            with_sid=True,
+        )
+
+        # expect a confirmation page
+        self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
+
+        # fish the matrix login token out of the body of the confirmation page
+        m = re.search(
+            'a href="%s.*loginToken=([^"]*)"' % (client_redirect_url,),
+            channel.text_body,
+        )
+        assert m, channel.text_body
+        login_token = m.group(1)
+
+        # Submit a logout
+        logout_token = fake_oidc_server.generate_logout_token(grant)
+        channel = self.submit_logout_token(logout_token)
+        self.assertEqual(channel.code, 200)
+
+        # Now try to exchange the login token
+        channel = make_request(
+            self.hs.get_reactor(),
+            self.site,
+            "POST",
+            "/login",
+            content={"type": "m.login.token", "token": login_token},
+        )
+        # It should have failed
+        self.assertEqual(channel.code, 403)
+
+    @override_config(
+        {
+            "oidc_providers": [
+                oidc_config(
+                    id="oidc",
+                    with_localpart_template=False,
+                    backchannel_logout_enabled=True,
+                )
+            ]
+        }
+    )
+    def test_logout_during_mapping(self) -> None:
+        """
+        It should stop ongoing user mapping session when receiving a logout token
+        """
+        fake_oidc_server = self.helper.fake_oidc_server()
+        user = "john"
+
+        # Get an authentication, and logout before submitting the logout token
+        client_redirect_url = "https://x"
+        userinfo = {"sub": user}
+        channel, grant = self.helper.auth_via_oidc(
+            fake_oidc_server,
+            userinfo,
+            client_redirect_url,
+            with_sid=True,
+        )
+
+        # Expect a user mapping page
+        self.assertEqual(channel.code, HTTPStatus.FOUND, channel.result)
+
+        # We should have a user_mapping_session cookie
+        cookie_headers = channel.headers.getRawHeaders("Set-Cookie")
+        assert cookie_headers
+        cookies: Dict[str, str] = {}
+        for h in cookie_headers:
+            key, value = h.split(";")[0].split("=", maxsplit=1)
+            cookies[key] = value
+
+        user_mapping_session_id = cookies["username_mapping_session"]
+
+        # Getting that session should not raise
+        session = self.hs.get_sso_handler().get_mapping_session(user_mapping_session_id)
+        self.assertIsNotNone(session)
+
+        # Submit a logout
+        logout_token = fake_oidc_server.generate_logout_token(grant)
+        channel = self.submit_logout_token(logout_token)
+        self.assertEqual(channel.code, 200)
+
+        # Now it should raise
+        with self.assertRaises(SynapseError):
+            self.hs.get_sso_handler().get_mapping_session(user_mapping_session_id)
+
+    @override_config(
+        {
+            "oidc_providers": [
+                oidc_config(
+                    id="oidc",
+                    with_localpart_template=True,
+                    backchannel_logout_enabled=False,
+                )
+            ]
+        }
+    )
+    def test_disabled(self) -> None:
+        """
+        Receiving a logout token should do nothing if it is disabled in the config
+        """
+        fake_oidc_server = self.helper.fake_oidc_server()
+        user = "john"
+
+        login_resp, grant = self.helper.login_via_oidc(
+            fake_oidc_server, user, with_sid=True
+        )
+        access_token: str = login_resp["access_token"]
+        self.helper.whoami(access_token, expect_code=HTTPStatus.OK)
+
+        # Logging out shouldn't work
+        logout_token = fake_oidc_server.generate_logout_token(grant)
+        channel = self.submit_logout_token(logout_token)
+        self.assertEqual(channel.code, 400)
+
+        # And the token should still be valid
+        self.helper.whoami(access_token, expect_code=HTTPStatus.OK)
+
+    @override_config(
+        {
+            "oidc_providers": [
+                oidc_config(
+                    id="oidc",
+                    with_localpart_template=True,
+                    backchannel_logout_enabled=True,
+                )
+            ]
+        }
+    )
+    def test_no_sid(self) -> None:
+        """
+        Receiving a logout token without `sid` during the login should do nothing
+        """
+        fake_oidc_server = self.helper.fake_oidc_server()
+        user = "john"
+
+        login_resp, grant = self.helper.login_via_oidc(
+            fake_oidc_server, user, with_sid=False
+        )
+        access_token: str = login_resp["access_token"]
+        self.helper.whoami(access_token, expect_code=HTTPStatus.OK)
+
+        # Logging out shouldn't work
+        logout_token = fake_oidc_server.generate_logout_token(grant)
+        channel = self.submit_logout_token(logout_token)
+        self.assertEqual(channel.code, 400)
+
+        # And the token should still be valid
+        self.helper.whoami(access_token, expect_code=HTTPStatus.OK)
+
+    @override_config(
+        {
+            "oidc_providers": [
+                oidc_config(
+                    "first",
+                    issuer="https://first-issuer.com/",
+                    with_localpart_template=True,
+                    backchannel_logout_enabled=True,
+                ),
+                oidc_config(
+                    "second",
+                    issuer="https://second-issuer.com/",
+                    with_localpart_template=True,
+                    backchannel_logout_enabled=True,
+                ),
+            ]
+        }
+    )
+    def test_multiple_providers(self) -> None:
+        """
+        It should be able to distinguish login tokens from two different IdPs
+        """
+        first_server = self.helper.fake_oidc_server(issuer="https://first-issuer.com/")
+        second_server = self.helper.fake_oidc_server(
+            issuer="https://second-issuer.com/"
+        )
+        user = "john"
+
+        login_resp, first_grant = self.helper.login_via_oidc(
+            first_server, user, with_sid=True, idp_id="oidc-first"
+        )
+        first_access_token: str = login_resp["access_token"]
+        self.helper.whoami(first_access_token, expect_code=HTTPStatus.OK)
+
+        login_resp, second_grant = self.helper.login_via_oidc(
+            second_server, user, with_sid=True, idp_id="oidc-second"
+        )
+        second_access_token: str = login_resp["access_token"]
+        self.helper.whoami(second_access_token, expect_code=HTTPStatus.OK)
+
+        # `sid` in the fake providers are generated by a counter, so the first grant of
+        # each provider should give the same SID
+        self.assertEqual(first_grant.sid, second_grant.sid)
+        self.assertEqual(first_grant.userinfo["sub"], second_grant.userinfo["sub"])
+
+        # Logging out of the first session
+        logout_token = first_server.generate_logout_token(first_grant)
+        channel = self.submit_logout_token(logout_token)
+        self.assertEqual(channel.code, 200)
+
+        self.helper.whoami(first_access_token, expect_code=HTTPStatus.UNAUTHORIZED)
+        self.helper.whoami(second_access_token, expect_code=HTTPStatus.OK)
+
+        # Logging out of the second session
+        logout_token = second_server.generate_logout_token(second_grant)
+        channel = self.submit_logout_token(logout_token)
+        self.assertEqual(channel.code, 200)
+
+        self.helper.whoami(second_access_token, expect_code=HTTPStatus.UNAUTHORIZED)
diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py
index 967d229223..706399fae5 100644
--- a/tests/rest/client/utils.py
+++ b/tests/rest/client/utils.py
@@ -553,6 +553,34 @@ class RestHelper:
 
         return channel.json_body
 
+    def whoami(
+        self,
+        access_token: str,
+        expect_code: Literal[HTTPStatus.OK, HTTPStatus.UNAUTHORIZED] = HTTPStatus.OK,
+    ) -> JsonDict:
+        """Perform a 'whoami' request, which can be a quick way to check for access
+        token validity
+
+        Args:
+            access_token: The user token to use during the request
+            expect_code: The return code to expect from attempting the whoami request
+        """
+        channel = make_request(
+            self.hs.get_reactor(),
+            self.site,
+            "GET",
+            "account/whoami",
+            access_token=access_token,
+        )
+
+        assert channel.code == expect_code, "Exepcted: %d, got %d, resp: %r" % (
+            expect_code,
+            channel.code,
+            channel.result["body"],
+        )
+
+        return channel.json_body
+
     def fake_oidc_server(self, issuer: str = TEST_OIDC_ISSUER) -> FakeOidcServer:
         """Create a ``FakeOidcServer``.
 
@@ -572,6 +600,7 @@ class RestHelper:
         fake_server: FakeOidcServer,
         remote_user_id: str,
         with_sid: bool = False,
+        idp_id: Optional[str] = None,
         expected_status: int = 200,
     ) -> Tuple[JsonDict, FakeAuthorizationGrant]:
         """Log in (as a new user) via OIDC
@@ -588,7 +617,11 @@ class RestHelper:
         client_redirect_url = "https://x"
         userinfo = {"sub": remote_user_id}
         channel, grant = self.auth_via_oidc(
-            fake_server, userinfo, client_redirect_url, with_sid=with_sid
+            fake_server,
+            userinfo,
+            client_redirect_url,
+            with_sid=with_sid,
+            idp_id=idp_id,
         )
 
         # expect a confirmation page
@@ -623,6 +656,7 @@ class RestHelper:
         client_redirect_url: Optional[str] = None,
         ui_auth_session_id: Optional[str] = None,
         with_sid: bool = False,
+        idp_id: Optional[str] = None,
     ) -> Tuple[FakeChannel, FakeAuthorizationGrant]:
         """Perform an OIDC authentication flow via a mock OIDC provider.
 
@@ -648,6 +682,7 @@ class RestHelper:
             ui_auth_session_id: if set, we will perform a UI Auth flow. The session id
                 of the UI auth.
             with_sid: if True, generates a random `sid` (OIDC session ID)
+            idp_id: if set, explicitely chooses one specific IDP
 
         Returns:
             A FakeChannel containing the result of calling the OIDC callback endpoint.
@@ -665,7 +700,9 @@ class RestHelper:
                 oauth_uri = self.initiate_sso_ui_auth(ui_auth_session_id, cookies)
             else:
                 # otherwise, hit the login redirect endpoint
-                oauth_uri = self.initiate_sso_login(client_redirect_url, cookies)
+                oauth_uri = self.initiate_sso_login(
+                    client_redirect_url, cookies, idp_id=idp_id
+                )
 
         # we now have a URI for the OIDC IdP, but we skip that and go straight
         # back to synapse's OIDC callback resource. However, we do need the "state"
@@ -742,7 +779,10 @@ class RestHelper:
         return channel, grant
 
     def initiate_sso_login(
-        self, client_redirect_url: Optional[str], cookies: MutableMapping[str, str]
+        self,
+        client_redirect_url: Optional[str],
+        cookies: MutableMapping[str, str],
+        idp_id: Optional[str] = None,
     ) -> str:
         """Make a request to the login-via-sso redirect endpoint, and return the target
 
@@ -753,6 +793,7 @@ class RestHelper:
             client_redirect_url: the client redirect URL to pass to the login redirect
                 endpoint
             cookies: any cookies returned will be added to this dict
+            idp_id: if set, explicitely chooses one specific IDP
 
         Returns:
             the URI that the client gets redirected to (ie, the SSO server)
@@ -761,6 +802,12 @@ class RestHelper:
         if client_redirect_url:
             params["redirectUrl"] = client_redirect_url
 
+        uri = "/_matrix/client/r0/login/sso/redirect"
+        if idp_id is not None:
+            uri = f"{uri}/{idp_id}"
+
+        uri = f"{uri}?{urllib.parse.urlencode(params)}"
+
         # hit the redirect url (which should redirect back to the redirect url. This
         # is the easiest way of figuring out what the Host header ought to be set to
         # to keep Synapse happy.
@@ -768,7 +815,7 @@ class RestHelper:
             self.hs.get_reactor(),
             self.site,
             "GET",
-            "/_matrix/client/r0/login/sso/redirect?" + urllib.parse.urlencode(params),
+            uri,
         )
         assert channel.code == 302
 
diff --git a/tests/server.py b/tests/server.py
index 8b1d186219..b1730fcc8d 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -362,6 +362,12 @@ def make_request(
     # Twisted expects to be at the end of the content when parsing the request.
     req.content.seek(0, SEEK_END)
 
+    # Old version of Twisted (<20.3.0) have issues with parsing x-www-form-urlencoded
+    # bodies if the Content-Length header is missing
+    req.requestHeaders.addRawHeader(
+        b"Content-Length", str(len(content)).encode("ascii")
+    )
+
     if access_token:
         req.requestHeaders.addRawHeader(
             b"Authorization", b"Bearer " + access_token.encode("ascii")
diff --git a/tests/test_utils/oidc.py b/tests/test_utils/oidc.py
index de134bbc89..1461d23ee8 100644
--- a/tests/test_utils/oidc.py
+++ b/tests/test_utils/oidc.py
@@ -51,6 +51,8 @@ class FakeOidcServer:
     get_userinfo_handler: Mock
     post_token_handler: Mock
 
+    sid_counter: int = 0
+
     def __init__(self, clock: Clock, issuer: str):
         from authlib.jose import ECKey, KeySet
 
@@ -146,7 +148,7 @@ class FakeOidcServer:
         return jws.serialize_compact(protected, json_payload, self._key).decode("utf-8")
 
     def generate_id_token(self, grant: FakeAuthorizationGrant) -> str:
-        now = self._clock.time()
+        now = int(self._clock.time())
         id_token = {
             **grant.userinfo,
             "iss": self.issuer,
@@ -166,6 +168,26 @@ class FakeOidcServer:
 
         return self._sign(id_token)
 
+    def generate_logout_token(self, grant: FakeAuthorizationGrant) -> str:
+        now = int(self._clock.time())
+        logout_token = {
+            "iss": self.issuer,
+            "aud": grant.client_id,
+            "iat": now,
+            "jti": random_string(10),
+            "events": {
+                "http://schemas.openid.net/event/backchannel-logout": {},
+            },
+        }
+
+        if grant.sid is not None:
+            logout_token["sid"] = grant.sid
+
+        if "sub" in grant.userinfo:
+            logout_token["sub"] = grant.userinfo["sub"]
+
+        return self._sign(logout_token)
+
     def id_token_override(self, overrides: dict):
         """Temporarily patch the ID token generated by the token endpoint."""
         return patch.object(self, "_id_token_overrides", overrides)
@@ -183,7 +205,8 @@ class FakeOidcServer:
         code = random_string(10)
         sid = None
         if with_sid:
-            sid = random_string(10)
+            sid = str(self.sid_counter)
+            self.sid_counter += 1
 
         grant = FakeAuthorizationGrant(
             userinfo=userinfo,