diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py
index c249a42bb6..967d229223 100644
--- a/tests/rest/client/utils.py
+++ b/tests/rest/client/utils.py
@@ -31,7 +31,6 @@ from typing import (
Tuple,
overload,
)
-from unittest.mock import patch
from urllib.parse import urlencode
import attr
@@ -46,8 +45,19 @@ from synapse.server import HomeServer
from synapse.types import JsonDict
from tests.server import FakeChannel, FakeSite, make_request
-from tests.test_utils import FakeResponse
from tests.test_utils.html_parsers import TestHtmlParser
+from tests.test_utils.oidc import FakeAuthorizationGrant, FakeOidcServer
+
+# an 'oidc_config' suitable for login_via_oidc.
+TEST_OIDC_ISSUER = "https://issuer.test/"
+TEST_OIDC_CONFIG = {
+ "enabled": True,
+ "issuer": TEST_OIDC_ISSUER,
+ "client_id": "test-client-id",
+ "client_secret": "test-client-secret",
+ "scopes": ["openid"],
+ "user_mapping_provider": {"config": {"localpart_template": "{{ user.sub }}"}},
+}
@attr.s(auto_attribs=True)
@@ -543,12 +553,28 @@ class RestHelper:
return channel.json_body
+ def fake_oidc_server(self, issuer: str = TEST_OIDC_ISSUER) -> FakeOidcServer:
+ """Create a ``FakeOidcServer``.
+
+ This can be used in conjuction with ``login_via_oidc``::
+
+ fake_oidc_server = self.helper.fake_oidc_server()
+ login_data, _ = self.helper.login_via_oidc(fake_oidc_server, "user")
+ """
+
+ return FakeOidcServer(
+ clock=self.hs.get_clock(),
+ issuer=issuer,
+ )
+
def login_via_oidc(
self,
+ fake_server: FakeOidcServer,
remote_user_id: str,
+ with_sid: bool = False,
expected_status: int = 200,
- ) -> JsonDict:
- """Log in via OIDC
+ ) -> Tuple[JsonDict, FakeAuthorizationGrant]:
+ """Log in (as a new user) via OIDC
Returns the result of the final token login.
@@ -560,7 +586,10 @@ class RestHelper:
the normal places.
"""
client_redirect_url = "https://x"
- channel = self.auth_via_oidc({"sub": remote_user_id}, client_redirect_url)
+ userinfo = {"sub": remote_user_id}
+ channel, grant = self.auth_via_oidc(
+ fake_server, userinfo, client_redirect_url, with_sid=with_sid
+ )
# expect a confirmation page
assert channel.code == HTTPStatus.OK, channel.result
@@ -585,14 +614,16 @@ class RestHelper:
assert (
channel.code == expected_status
), f"unexpected status in response: {channel.code}"
- return channel.json_body
+ return channel.json_body, grant
def auth_via_oidc(
self,
+ fake_server: FakeOidcServer,
user_info_dict: JsonDict,
client_redirect_url: Optional[str] = None,
ui_auth_session_id: Optional[str] = None,
- ) -> FakeChannel:
+ with_sid: bool = False,
+ ) -> Tuple[FakeChannel, FakeAuthorizationGrant]:
"""Perform an OIDC authentication flow via a mock OIDC provider.
This can be used for either login or user-interactive auth.
@@ -616,6 +647,7 @@ class RestHelper:
the login redirect endpoint
ui_auth_session_id: if set, we will perform a UI Auth flow. The session id
of the UI auth.
+ with_sid: if True, generates a random `sid` (OIDC session ID)
Returns:
A FakeChannel containing the result of calling the OIDC callback endpoint.
@@ -625,14 +657,15 @@ class RestHelper:
cookies: Dict[str, str] = {}
- # if we're doing a ui auth, hit the ui auth redirect endpoint
- if ui_auth_session_id:
- # can't set the client redirect url for UI Auth
- assert client_redirect_url is None
- oauth_uri = self.initiate_sso_ui_auth(ui_auth_session_id, cookies)
- else:
- # otherwise, hit the login redirect endpoint
- oauth_uri = self.initiate_sso_login(client_redirect_url, cookies)
+ with fake_server.patch_homeserver(hs=self.hs):
+ # if we're doing a ui auth, hit the ui auth redirect endpoint
+ if ui_auth_session_id:
+ # can't set the client redirect url for UI Auth
+ assert client_redirect_url is None
+ oauth_uri = self.initiate_sso_ui_auth(ui_auth_session_id, cookies)
+ else:
+ # otherwise, hit the login redirect endpoint
+ oauth_uri = self.initiate_sso_login(client_redirect_url, cookies)
# we now have a URI for the OIDC IdP, but we skip that and go straight
# back to synapse's OIDC callback resource. However, we do need the "state"
@@ -640,17 +673,21 @@ class RestHelper:
# that synapse passes to the client.
oauth_uri_path, _ = oauth_uri.split("?", 1)
- assert oauth_uri_path == TEST_OIDC_AUTH_ENDPOINT, (
+ assert oauth_uri_path == fake_server.authorization_endpoint, (
"unexpected SSO URI " + oauth_uri_path
)
- return self.complete_oidc_auth(oauth_uri, cookies, user_info_dict)
+ return self.complete_oidc_auth(
+ fake_server, oauth_uri, cookies, user_info_dict, with_sid=with_sid
+ )
def complete_oidc_auth(
self,
+ fake_serer: FakeOidcServer,
oauth_uri: str,
cookies: Mapping[str, str],
user_info_dict: JsonDict,
- ) -> FakeChannel:
+ with_sid: bool = False,
+ ) -> Tuple[FakeChannel, FakeAuthorizationGrant]:
"""Mock out an OIDC authentication flow
Assumes that an OIDC auth has been initiated by one of initiate_sso_login or
@@ -661,50 +698,37 @@ class RestHelper:
Requires the OIDC callback resource to be mounted at the normal place.
Args:
+ fake_server: the fake OIDC server with which the auth should be done
oauth_uri: the OIDC URI returned by synapse's redirect endpoint (ie,
from initiate_sso_login or initiate_sso_ui_auth).
cookies: the cookies set by synapse's redirect endpoint, which will be
sent back to the callback endpoint.
user_info_dict: the remote userinfo that the OIDC provider should present.
Typically this should be '{"sub": "<remote user id>"}'.
+ with_sid: if True, generates a random `sid` (OIDC session ID)
Returns:
A FakeChannel containing the result of calling the OIDC callback endpoint.
"""
_, oauth_uri_qs = oauth_uri.split("?", 1)
params = urllib.parse.parse_qs(oauth_uri_qs)
+
+ code, grant = fake_serer.start_authorization(
+ scope=params["scope"][0],
+ userinfo=user_info_dict,
+ client_id=params["client_id"][0],
+ redirect_uri=params["redirect_uri"][0],
+ nonce=params["nonce"][0],
+ with_sid=with_sid,
+ )
+ state = params["state"][0]
+
callback_uri = "%s?%s" % (
urllib.parse.urlparse(params["redirect_uri"][0]).path,
- urllib.parse.urlencode({"state": params["state"][0], "code": "TEST_CODE"}),
+ urllib.parse.urlencode({"state": state, "code": code}),
)
- # before we hit the callback uri, stub out some methods in the http client so
- # that we don't have to handle full HTTPS requests.
- # (expected url, json response) pairs, in the order we expect them.
- expected_requests = [
- # first we get a hit to the token endpoint, which we tell to return
- # a dummy OIDC access token
- (TEST_OIDC_TOKEN_ENDPOINT, {"access_token": "TEST"}),
- # and then one to the user_info endpoint, which returns our remote user id.
- (TEST_OIDC_USERINFO_ENDPOINT, user_info_dict),
- ]
-
- async def mock_req(
- method: str,
- uri: str,
- data: Optional[dict] = None,
- headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
- ):
- (expected_uri, resp_obj) = expected_requests.pop(0)
- assert uri == expected_uri
- resp = FakeResponse(
- code=HTTPStatus.OK,
- phrase=b"OK",
- body=json.dumps(resp_obj).encode("utf-8"),
- )
- return resp
-
- with patch.object(self.hs.get_proxied_http_client(), "request", mock_req):
+ with fake_serer.patch_homeserver(hs=self.hs):
# now hit the callback URI with the right params and a made-up code
channel = make_request(
self.hs.get_reactor(),
@@ -715,7 +739,7 @@ class RestHelper:
("Cookie", "%s=%s" % (k, v)) for (k, v) in cookies.items()
],
)
- return channel
+ return channel, grant
def initiate_sso_login(
self, client_redirect_url: Optional[str], cookies: MutableMapping[str, str]
@@ -806,21 +830,3 @@ class RestHelper:
assert len(p.links) == 1, "not exactly one link in confirmation page"
oauth_uri = p.links[0]
return oauth_uri
-
-
-# an 'oidc_config' suitable for login_via_oidc.
-TEST_OIDC_AUTH_ENDPOINT = "https://issuer.test/auth"
-TEST_OIDC_TOKEN_ENDPOINT = "https://issuer.test/token"
-TEST_OIDC_USERINFO_ENDPOINT = "https://issuer.test/userinfo"
-TEST_OIDC_CONFIG = {
- "enabled": True,
- "discover": False,
- "issuer": "https://issuer.test",
- "client_id": "test-client-id",
- "client_secret": "test-client-secret",
- "scopes": ["profile"],
- "authorization_endpoint": TEST_OIDC_AUTH_ENDPOINT,
- "token_endpoint": TEST_OIDC_TOKEN_ENDPOINT,
- "userinfo_endpoint": TEST_OIDC_USERINFO_ENDPOINT,
- "user_mapping_provider": {"config": {"localpart_template": "{{ user.sub }}"}},
-}
|