diff options
Diffstat (limited to 'tests/rest/client/v1/utils.py')
-rw-r--r-- | tests/rest/client/v1/utils.py | 62 |
1 files changed, 36 insertions, 26 deletions
diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py index c6647dbe08..b1333df82d 100644 --- a/tests/rest/client/v1/utils.py +++ b/tests/rest/client/v1/utils.py @@ -20,8 +20,7 @@ import json import re import time import urllib.parse -from html.parser import HTMLParser -from typing import Any, Dict, Iterable, List, MutableMapping, Optional, Tuple +from typing import Any, Dict, Mapping, MutableMapping, Optional from mock import patch @@ -35,6 +34,7 @@ from synapse.types import JsonDict from tests.server import FakeChannel, FakeSite, make_request from tests.test_utils import FakeResponse +from tests.test_utils.html_parsers import TestHtmlParser @attr.s @@ -440,10 +440,36 @@ class RestHelper: # param that synapse passes to the IdP via query params, as well as the cookie # that synapse passes to the client. - oauth_uri_path, oauth_uri_qs = oauth_uri.split("?", 1) + oauth_uri_path, _ = oauth_uri.split("?", 1) assert oauth_uri_path == TEST_OIDC_AUTH_ENDPOINT, ( "unexpected SSO URI " + oauth_uri_path ) + return self.complete_oidc_auth(oauth_uri, cookies, user_info_dict) + + def complete_oidc_auth( + self, oauth_uri: str, cookies: Mapping[str, str], user_info_dict: JsonDict, + ) -> FakeChannel: + """Mock out an OIDC authentication flow + + Assumes that an OIDC auth has been initiated by one of initiate_sso_login or + initiate_sso_ui_auth; completes the OIDC bits of the flow by making a request to + Synapse's OIDC callback endpoint, intercepting the HTTP requests that will get + sent back to the OIDC provider. + + Requires the OIDC callback resource to be mounted at the normal place. + + Args: + oauth_uri: the OIDC URI returned by synapse's redirect endpoint (ie, + from initiate_sso_login or initiate_sso_ui_auth). + cookies: the cookies set by synapse's redirect endpoint, which will be + sent back to the callback endpoint. + user_info_dict: the remote userinfo that the OIDC provider should present. + Typically this should be '{"sub": "<remote user id>"}'. + + Returns: + A FakeChannel containing the result of calling the OIDC callback endpoint. + """ + _, oauth_uri_qs = oauth_uri.split("?", 1) params = urllib.parse.parse_qs(oauth_uri_qs) callback_uri = "%s?%s" % ( urllib.parse.urlparse(params["redirect_uri"][0]).path, @@ -456,9 +482,9 @@ class RestHelper: expected_requests = [ # first we get a hit to the token endpoint, which we tell to return # a dummy OIDC access token - ("https://issuer.test/token", {"access_token": "TEST"}), + (TEST_OIDC_TOKEN_ENDPOINT, {"access_token": "TEST"}), # and then one to the user_info endpoint, which returns our remote user id. - ("https://issuer.test/userinfo", user_info_dict), + (TEST_OIDC_USERINFO_ENDPOINT, user_info_dict), ] async def mock_req(method: str, uri: str, data=None, headers=None): @@ -542,25 +568,7 @@ class RestHelper: channel.extract_cookies(cookies) # parse the confirmation page to fish out the link. - class ConfirmationPageParser(HTMLParser): - def __init__(self): - super().__init__() - - self.links = [] # type: List[str] - - def handle_starttag( - self, tag: str, attrs: Iterable[Tuple[str, Optional[str]]] - ) -> None: - attr_dict = dict(attrs) - if tag == "a": - href = attr_dict["href"] - if href: - self.links.append(href) - - def error(_, message): - raise AssertionError(message) - - p = ConfirmationPageParser() + p = TestHtmlParser() p.feed(channel.text_body) p.close() assert len(p.links) == 1, "not exactly one link in confirmation page" @@ -570,6 +578,8 @@ class RestHelper: # an 'oidc_config' suitable for login_via_oidc. TEST_OIDC_AUTH_ENDPOINT = "https://issuer.test/auth" +TEST_OIDC_TOKEN_ENDPOINT = "https://issuer.test/token" +TEST_OIDC_USERINFO_ENDPOINT = "https://issuer.test/userinfo" TEST_OIDC_CONFIG = { "enabled": True, "discover": False, @@ -578,7 +588,7 @@ TEST_OIDC_CONFIG = { "client_secret": "test-client-secret", "scopes": ["profile"], "authorization_endpoint": TEST_OIDC_AUTH_ENDPOINT, - "token_endpoint": "https://issuer.test/token", - "userinfo_endpoint": "https://issuer.test/userinfo", + "token_endpoint": TEST_OIDC_TOKEN_ENDPOINT, + "userinfo_endpoint": TEST_OIDC_USERINFO_ENDPOINT, "user_mapping_provider": {"config": {"localpart_template": "{{ user.sub }}"}}, } |