diff --git a/tests/rest/client/test_auth.py b/tests/rest/client/test_auth.py
index ebf653d018..847294dc8e 100644
--- a/tests/rest/client/test_auth.py
+++ b/tests/rest/client/test_auth.py
@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import re
from http import HTTPStatus
from typing import Any, Dict, List, Optional, Tuple, Union
@@ -21,7 +22,7 @@ from twisted.web.resource import Resource
import synapse.rest.admin
from synapse.api.constants import ApprovalNoticeMedium, LoginType
-from synapse.api.errors import Codes
+from synapse.api.errors import Codes, SynapseError
from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker
from synapse.rest.client import account, auth, devices, login, logout, register
from synapse.rest.synapse.client import build_synapse_client_resource_tree
@@ -32,8 +33,8 @@ from synapse.util import Clock
from tests import unittest
from tests.handlers.test_oidc import HAS_OIDC
-from tests.rest.client.utils import TEST_OIDC_CONFIG
-from tests.server import FakeChannel
+from tests.rest.client.utils import TEST_OIDC_CONFIG, TEST_OIDC_ISSUER
+from tests.server import FakeChannel, make_request
from tests.unittest import override_config, skip_unless
@@ -638,19 +639,6 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
{"refresh_token": refresh_token},
)
- def is_access_token_valid(self, access_token: str) -> bool:
- """
- Checks whether an access token is valid, returning whether it is or not.
- """
- code = self.make_request(
- "GET", "/_matrix/client/v3/account/whoami", access_token=access_token
- ).code
-
- # Either 200 or 401 is what we get back; anything else is a bug.
- assert code in {HTTPStatus.OK, HTTPStatus.UNAUTHORIZED}
-
- return code == HTTPStatus.OK
-
def test_login_issue_refresh_token(self) -> None:
"""
A login response should include a refresh_token only if asked.
@@ -847,29 +835,37 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
self.reactor.advance(59.0)
# Both tokens should still be valid.
- self.assertTrue(self.is_access_token_valid(refreshable_access_token))
- self.assertTrue(self.is_access_token_valid(nonrefreshable_access_token))
+ self.helper.whoami(refreshable_access_token, expect_code=HTTPStatus.OK)
+ self.helper.whoami(nonrefreshable_access_token, expect_code=HTTPStatus.OK)
# Advance to 61 s (just past 1 minute, the time of expiry)
self.reactor.advance(2.0)
# Only the non-refreshable token is still valid.
- self.assertFalse(self.is_access_token_valid(refreshable_access_token))
- self.assertTrue(self.is_access_token_valid(nonrefreshable_access_token))
+ self.helper.whoami(
+ refreshable_access_token, expect_code=HTTPStatus.UNAUTHORIZED
+ )
+ self.helper.whoami(nonrefreshable_access_token, expect_code=HTTPStatus.OK)
# Advance to 599 s (just shy of 10 minutes, the time of expiry)
self.reactor.advance(599.0 - 61.0)
# It's still the case that only the non-refreshable token is still valid.
- self.assertFalse(self.is_access_token_valid(refreshable_access_token))
- self.assertTrue(self.is_access_token_valid(nonrefreshable_access_token))
+ self.helper.whoami(
+ refreshable_access_token, expect_code=HTTPStatus.UNAUTHORIZED
+ )
+ self.helper.whoami(nonrefreshable_access_token, expect_code=HTTPStatus.OK)
# Advance to 601 s (just past 10 minutes, the time of expiry)
self.reactor.advance(2.0)
# Now neither token is valid.
- self.assertFalse(self.is_access_token_valid(refreshable_access_token))
- self.assertFalse(self.is_access_token_valid(nonrefreshable_access_token))
+ self.helper.whoami(
+ refreshable_access_token, expect_code=HTTPStatus.UNAUTHORIZED
+ )
+ self.helper.whoami(
+ nonrefreshable_access_token, expect_code=HTTPStatus.UNAUTHORIZED
+ )
@override_config(
{"refreshable_access_token_lifetime": "1m", "refresh_token_lifetime": "2m"}
@@ -1165,3 +1161,349 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
# and no refresh token
self.assertEqual(_table_length("access_tokens"), 0)
self.assertEqual(_table_length("refresh_tokens"), 0)
+
+
+def oidc_config(
+ id: str, with_localpart_template: bool, **kwargs: Any
+) -> Dict[str, Any]:
+ """Sample OIDC provider config used in backchannel logout tests.
+
+ Args:
+ id: IDP ID for this provider
+ with_localpart_template: Set to `true` to have a default localpart_template in
+ the `user_mapping_provider` config and skip the user mapping session
+ **kwargs: rest of the config
+
+ Returns:
+ A dict suitable for the `oidc_config` or the `oidc_providers[]` parts of
+ the HS config
+ """
+ config: Dict[str, Any] = {
+ "idp_id": id,
+ "idp_name": id,
+ "issuer": TEST_OIDC_ISSUER,
+ "client_id": "test-client-id",
+ "client_secret": "test-client-secret",
+ "scopes": ["openid"],
+ }
+
+ if with_localpart_template:
+ config["user_mapping_provider"] = {
+ "config": {"localpart_template": "{{ user.sub }}"}
+ }
+ else:
+ config["user_mapping_provider"] = {"config": {}}
+
+ config.update(kwargs)
+
+ return config
+
+
+@skip_unless(HAS_OIDC, "Requires OIDC")
+class OidcBackchannelLogoutTests(unittest.HomeserverTestCase):
+ servlets = [
+ account.register_servlets,
+ login.register_servlets,
+ ]
+
+ def default_config(self) -> Dict[str, Any]:
+ config = super().default_config()
+
+ # public_baseurl uses an http:// scheme because FakeChannel.isSecure() returns
+ # False, so synapse will see the requested uri as http://..., so using http in
+ # the public_baseurl stops Synapse trying to redirect to https.
+ config["public_baseurl"] = "http://synapse.test"
+
+ return config
+
+ def create_resource_dict(self) -> Dict[str, Resource]:
+ resource_dict = super().create_resource_dict()
+ resource_dict.update(build_synapse_client_resource_tree(self.hs))
+ return resource_dict
+
+ def submit_logout_token(self, logout_token: str) -> FakeChannel:
+ return self.make_request(
+ "POST",
+ "/_synapse/client/oidc/backchannel_logout",
+ content=f"logout_token={logout_token}",
+ content_is_form=True,
+ )
+
+ @override_config(
+ {
+ "oidc_providers": [
+ oidc_config(
+ id="oidc",
+ with_localpart_template=True,
+ backchannel_logout_enabled=True,
+ )
+ ]
+ }
+ )
+ def test_simple_logout(self) -> None:
+ """
+ Receiving a logout token should logout the user
+ """
+ fake_oidc_server = self.helper.fake_oidc_server()
+ user = "john"
+
+ login_resp, first_grant = self.helper.login_via_oidc(
+ fake_oidc_server, user, with_sid=True
+ )
+ first_access_token: str = login_resp["access_token"]
+ self.helper.whoami(first_access_token, expect_code=HTTPStatus.OK)
+
+ login_resp, second_grant = self.helper.login_via_oidc(
+ fake_oidc_server, user, with_sid=True
+ )
+ second_access_token: str = login_resp["access_token"]
+ self.helper.whoami(second_access_token, expect_code=HTTPStatus.OK)
+
+ self.assertNotEqual(first_grant.sid, second_grant.sid)
+ self.assertEqual(first_grant.userinfo["sub"], second_grant.userinfo["sub"])
+
+ # Logging out of the first session
+ logout_token = fake_oidc_server.generate_logout_token(first_grant)
+ channel = self.submit_logout_token(logout_token)
+ self.assertEqual(channel.code, 200)
+
+ self.helper.whoami(first_access_token, expect_code=HTTPStatus.UNAUTHORIZED)
+ self.helper.whoami(second_access_token, expect_code=HTTPStatus.OK)
+
+ # Logging out of the second session
+ logout_token = fake_oidc_server.generate_logout_token(second_grant)
+ channel = self.submit_logout_token(logout_token)
+ self.assertEqual(channel.code, 200)
+
+ @override_config(
+ {
+ "oidc_providers": [
+ oidc_config(
+ id="oidc",
+ with_localpart_template=True,
+ backchannel_logout_enabled=True,
+ )
+ ]
+ }
+ )
+ def test_logout_during_login(self) -> None:
+ """
+ It should revoke login tokens when receiving a logout token
+ """
+ fake_oidc_server = self.helper.fake_oidc_server()
+ user = "john"
+
+ # Get an authentication, and logout before submitting the logout token
+ client_redirect_url = "https://x"
+ userinfo = {"sub": user}
+ channel, grant = self.helper.auth_via_oidc(
+ fake_oidc_server,
+ userinfo,
+ client_redirect_url,
+ with_sid=True,
+ )
+
+ # expect a confirmation page
+ self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
+
+ # fish the matrix login token out of the body of the confirmation page
+ m = re.search(
+ 'a href="%s.*loginToken=([^"]*)"' % (client_redirect_url,),
+ channel.text_body,
+ )
+ assert m, channel.text_body
+ login_token = m.group(1)
+
+ # Submit a logout
+ logout_token = fake_oidc_server.generate_logout_token(grant)
+ channel = self.submit_logout_token(logout_token)
+ self.assertEqual(channel.code, 200)
+
+ # Now try to exchange the login token
+ channel = make_request(
+ self.hs.get_reactor(),
+ self.site,
+ "POST",
+ "/login",
+ content={"type": "m.login.token", "token": login_token},
+ )
+ # It should have failed
+ self.assertEqual(channel.code, 403)
+
+ @override_config(
+ {
+ "oidc_providers": [
+ oidc_config(
+ id="oidc",
+ with_localpart_template=False,
+ backchannel_logout_enabled=True,
+ )
+ ]
+ }
+ )
+ def test_logout_during_mapping(self) -> None:
+ """
+ It should stop ongoing user mapping session when receiving a logout token
+ """
+ fake_oidc_server = self.helper.fake_oidc_server()
+ user = "john"
+
+ # Get an authentication, and logout before submitting the logout token
+ client_redirect_url = "https://x"
+ userinfo = {"sub": user}
+ channel, grant = self.helper.auth_via_oidc(
+ fake_oidc_server,
+ userinfo,
+ client_redirect_url,
+ with_sid=True,
+ )
+
+ # Expect a user mapping page
+ self.assertEqual(channel.code, HTTPStatus.FOUND, channel.result)
+
+ # We should have a user_mapping_session cookie
+ cookie_headers = channel.headers.getRawHeaders("Set-Cookie")
+ assert cookie_headers
+ cookies: Dict[str, str] = {}
+ for h in cookie_headers:
+ key, value = h.split(";")[0].split("=", maxsplit=1)
+ cookies[key] = value
+
+ user_mapping_session_id = cookies["username_mapping_session"]
+
+ # Getting that session should not raise
+ session = self.hs.get_sso_handler().get_mapping_session(user_mapping_session_id)
+ self.assertIsNotNone(session)
+
+ # Submit a logout
+ logout_token = fake_oidc_server.generate_logout_token(grant)
+ channel = self.submit_logout_token(logout_token)
+ self.assertEqual(channel.code, 200)
+
+ # Now it should raise
+ with self.assertRaises(SynapseError):
+ self.hs.get_sso_handler().get_mapping_session(user_mapping_session_id)
+
+ @override_config(
+ {
+ "oidc_providers": [
+ oidc_config(
+ id="oidc",
+ with_localpart_template=True,
+ backchannel_logout_enabled=False,
+ )
+ ]
+ }
+ )
+ def test_disabled(self) -> None:
+ """
+ Receiving a logout token should do nothing if it is disabled in the config
+ """
+ fake_oidc_server = self.helper.fake_oidc_server()
+ user = "john"
+
+ login_resp, grant = self.helper.login_via_oidc(
+ fake_oidc_server, user, with_sid=True
+ )
+ access_token: str = login_resp["access_token"]
+ self.helper.whoami(access_token, expect_code=HTTPStatus.OK)
+
+ # Logging out shouldn't work
+ logout_token = fake_oidc_server.generate_logout_token(grant)
+ channel = self.submit_logout_token(logout_token)
+ self.assertEqual(channel.code, 400)
+
+ # And the token should still be valid
+ self.helper.whoami(access_token, expect_code=HTTPStatus.OK)
+
+ @override_config(
+ {
+ "oidc_providers": [
+ oidc_config(
+ id="oidc",
+ with_localpart_template=True,
+ backchannel_logout_enabled=True,
+ )
+ ]
+ }
+ )
+ def test_no_sid(self) -> None:
+ """
+ Receiving a logout token without `sid` during the login should do nothing
+ """
+ fake_oidc_server = self.helper.fake_oidc_server()
+ user = "john"
+
+ login_resp, grant = self.helper.login_via_oidc(
+ fake_oidc_server, user, with_sid=False
+ )
+ access_token: str = login_resp["access_token"]
+ self.helper.whoami(access_token, expect_code=HTTPStatus.OK)
+
+ # Logging out shouldn't work
+ logout_token = fake_oidc_server.generate_logout_token(grant)
+ channel = self.submit_logout_token(logout_token)
+ self.assertEqual(channel.code, 400)
+
+ # And the token should still be valid
+ self.helper.whoami(access_token, expect_code=HTTPStatus.OK)
+
+ @override_config(
+ {
+ "oidc_providers": [
+ oidc_config(
+ "first",
+ issuer="https://first-issuer.com/",
+ with_localpart_template=True,
+ backchannel_logout_enabled=True,
+ ),
+ oidc_config(
+ "second",
+ issuer="https://second-issuer.com/",
+ with_localpart_template=True,
+ backchannel_logout_enabled=True,
+ ),
+ ]
+ }
+ )
+ def test_multiple_providers(self) -> None:
+ """
+ It should be able to distinguish login tokens from two different IdPs
+ """
+ first_server = self.helper.fake_oidc_server(issuer="https://first-issuer.com/")
+ second_server = self.helper.fake_oidc_server(
+ issuer="https://second-issuer.com/"
+ )
+ user = "john"
+
+ login_resp, first_grant = self.helper.login_via_oidc(
+ first_server, user, with_sid=True, idp_id="oidc-first"
+ )
+ first_access_token: str = login_resp["access_token"]
+ self.helper.whoami(first_access_token, expect_code=HTTPStatus.OK)
+
+ login_resp, second_grant = self.helper.login_via_oidc(
+ second_server, user, with_sid=True, idp_id="oidc-second"
+ )
+ second_access_token: str = login_resp["access_token"]
+ self.helper.whoami(second_access_token, expect_code=HTTPStatus.OK)
+
+ # `sid` in the fake providers are generated by a counter, so the first grant of
+ # each provider should give the same SID
+ self.assertEqual(first_grant.sid, second_grant.sid)
+ self.assertEqual(first_grant.userinfo["sub"], second_grant.userinfo["sub"])
+
+ # Logging out of the first session
+ logout_token = first_server.generate_logout_token(first_grant)
+ channel = self.submit_logout_token(logout_token)
+ self.assertEqual(channel.code, 200)
+
+ self.helper.whoami(first_access_token, expect_code=HTTPStatus.UNAUTHORIZED)
+ self.helper.whoami(second_access_token, expect_code=HTTPStatus.OK)
+
+ # Logging out of the second session
+ logout_token = second_server.generate_logout_token(second_grant)
+ channel = self.submit_logout_token(logout_token)
+ self.assertEqual(channel.code, 200)
+
+ self.helper.whoami(second_access_token, expect_code=HTTPStatus.UNAUTHORIZED)
diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py
index 967d229223..706399fae5 100644
--- a/tests/rest/client/utils.py
+++ b/tests/rest/client/utils.py
@@ -553,6 +553,34 @@ class RestHelper:
return channel.json_body
+ def whoami(
+ self,
+ access_token: str,
+ expect_code: Literal[HTTPStatus.OK, HTTPStatus.UNAUTHORIZED] = HTTPStatus.OK,
+ ) -> JsonDict:
+ """Perform a 'whoami' request, which can be a quick way to check for access
+ token validity
+
+ Args:
+ access_token: The user token to use during the request
+ expect_code: The return code to expect from attempting the whoami request
+ """
+ channel = make_request(
+ self.hs.get_reactor(),
+ self.site,
+ "GET",
+ "account/whoami",
+ access_token=access_token,
+ )
+
+ assert channel.code == expect_code, "Exepcted: %d, got %d, resp: %r" % (
+ expect_code,
+ channel.code,
+ channel.result["body"],
+ )
+
+ return channel.json_body
+
def fake_oidc_server(self, issuer: str = TEST_OIDC_ISSUER) -> FakeOidcServer:
"""Create a ``FakeOidcServer``.
@@ -572,6 +600,7 @@ class RestHelper:
fake_server: FakeOidcServer,
remote_user_id: str,
with_sid: bool = False,
+ idp_id: Optional[str] = None,
expected_status: int = 200,
) -> Tuple[JsonDict, FakeAuthorizationGrant]:
"""Log in (as a new user) via OIDC
@@ -588,7 +617,11 @@ class RestHelper:
client_redirect_url = "https://x"
userinfo = {"sub": remote_user_id}
channel, grant = self.auth_via_oidc(
- fake_server, userinfo, client_redirect_url, with_sid=with_sid
+ fake_server,
+ userinfo,
+ client_redirect_url,
+ with_sid=with_sid,
+ idp_id=idp_id,
)
# expect a confirmation page
@@ -623,6 +656,7 @@ class RestHelper:
client_redirect_url: Optional[str] = None,
ui_auth_session_id: Optional[str] = None,
with_sid: bool = False,
+ idp_id: Optional[str] = None,
) -> Tuple[FakeChannel, FakeAuthorizationGrant]:
"""Perform an OIDC authentication flow via a mock OIDC provider.
@@ -648,6 +682,7 @@ class RestHelper:
ui_auth_session_id: if set, we will perform a UI Auth flow. The session id
of the UI auth.
with_sid: if True, generates a random `sid` (OIDC session ID)
+ idp_id: if set, explicitely chooses one specific IDP
Returns:
A FakeChannel containing the result of calling the OIDC callback endpoint.
@@ -665,7 +700,9 @@ class RestHelper:
oauth_uri = self.initiate_sso_ui_auth(ui_auth_session_id, cookies)
else:
# otherwise, hit the login redirect endpoint
- oauth_uri = self.initiate_sso_login(client_redirect_url, cookies)
+ oauth_uri = self.initiate_sso_login(
+ client_redirect_url, cookies, idp_id=idp_id
+ )
# we now have a URI for the OIDC IdP, but we skip that and go straight
# back to synapse's OIDC callback resource. However, we do need the "state"
@@ -742,7 +779,10 @@ class RestHelper:
return channel, grant
def initiate_sso_login(
- self, client_redirect_url: Optional[str], cookies: MutableMapping[str, str]
+ self,
+ client_redirect_url: Optional[str],
+ cookies: MutableMapping[str, str],
+ idp_id: Optional[str] = None,
) -> str:
"""Make a request to the login-via-sso redirect endpoint, and return the target
@@ -753,6 +793,7 @@ class RestHelper:
client_redirect_url: the client redirect URL to pass to the login redirect
endpoint
cookies: any cookies returned will be added to this dict
+ idp_id: if set, explicitely chooses one specific IDP
Returns:
the URI that the client gets redirected to (ie, the SSO server)
@@ -761,6 +802,12 @@ class RestHelper:
if client_redirect_url:
params["redirectUrl"] = client_redirect_url
+ uri = "/_matrix/client/r0/login/sso/redirect"
+ if idp_id is not None:
+ uri = f"{uri}/{idp_id}"
+
+ uri = f"{uri}?{urllib.parse.urlencode(params)}"
+
# hit the redirect url (which should redirect back to the redirect url. This
# is the easiest way of figuring out what the Host header ought to be set to
# to keep Synapse happy.
@@ -768,7 +815,7 @@ class RestHelper:
self.hs.get_reactor(),
self.site,
"GET",
- "/_matrix/client/r0/login/sso/redirect?" + urllib.parse.urlencode(params),
+ uri,
)
assert channel.code == 302
diff --git a/tests/server.py b/tests/server.py
index 8b1d186219..b1730fcc8d 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -362,6 +362,12 @@ def make_request(
# Twisted expects to be at the end of the content when parsing the request.
req.content.seek(0, SEEK_END)
+ # Old version of Twisted (<20.3.0) have issues with parsing x-www-form-urlencoded
+ # bodies if the Content-Length header is missing
+ req.requestHeaders.addRawHeader(
+ b"Content-Length", str(len(content)).encode("ascii")
+ )
+
if access_token:
req.requestHeaders.addRawHeader(
b"Authorization", b"Bearer " + access_token.encode("ascii")
diff --git a/tests/test_utils/oidc.py b/tests/test_utils/oidc.py
index de134bbc89..1461d23ee8 100644
--- a/tests/test_utils/oidc.py
+++ b/tests/test_utils/oidc.py
@@ -51,6 +51,8 @@ class FakeOidcServer:
get_userinfo_handler: Mock
post_token_handler: Mock
+ sid_counter: int = 0
+
def __init__(self, clock: Clock, issuer: str):
from authlib.jose import ECKey, KeySet
@@ -146,7 +148,7 @@ class FakeOidcServer:
return jws.serialize_compact(protected, json_payload, self._key).decode("utf-8")
def generate_id_token(self, grant: FakeAuthorizationGrant) -> str:
- now = self._clock.time()
+ now = int(self._clock.time())
id_token = {
**grant.userinfo,
"iss": self.issuer,
@@ -166,6 +168,26 @@ class FakeOidcServer:
return self._sign(id_token)
+ def generate_logout_token(self, grant: FakeAuthorizationGrant) -> str:
+ now = int(self._clock.time())
+ logout_token = {
+ "iss": self.issuer,
+ "aud": grant.client_id,
+ "iat": now,
+ "jti": random_string(10),
+ "events": {
+ "http://schemas.openid.net/event/backchannel-logout": {},
+ },
+ }
+
+ if grant.sid is not None:
+ logout_token["sid"] = grant.sid
+
+ if "sub" in grant.userinfo:
+ logout_token["sub"] = grant.userinfo["sub"]
+
+ return self._sign(logout_token)
+
def id_token_override(self, overrides: dict):
"""Temporarily patch the ID token generated by the token endpoint."""
return patch.object(self, "_id_token_overrides", overrides)
@@ -183,7 +205,8 @@ class FakeOidcServer:
code = random_string(10)
sid = None
if with_sid:
- sid = random_string(10)
+ sid = str(self.sid_counter)
+ self.sid_counter += 1
grant = FakeAuthorizationGrant(
userinfo=userinfo,
|