summary refs log tree commit diff
path: root/tests/rest/client/v1/utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/rest/client/v1/utils.py')
-rw-r--r--tests/rest/client/v1/utils.py62
1 files changed, 36 insertions, 26 deletions
diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py
index c6647dbe08..b1333df82d 100644
--- a/tests/rest/client/v1/utils.py
+++ b/tests/rest/client/v1/utils.py
@@ -20,8 +20,7 @@ import json
 import re
 import time
 import urllib.parse
-from html.parser import HTMLParser
-from typing import Any, Dict, Iterable, List, MutableMapping, Optional, Tuple
+from typing import Any, Dict, Mapping, MutableMapping, Optional
 
 from mock import patch
 
@@ -35,6 +34,7 @@ from synapse.types import JsonDict
 
 from tests.server import FakeChannel, FakeSite, make_request
 from tests.test_utils import FakeResponse
+from tests.test_utils.html_parsers import TestHtmlParser
 
 
 @attr.s
@@ -440,10 +440,36 @@ class RestHelper:
         # param that synapse passes to the IdP via query params, as well as the cookie
         # that synapse passes to the client.
 
-        oauth_uri_path, oauth_uri_qs = oauth_uri.split("?", 1)
+        oauth_uri_path, _ = oauth_uri.split("?", 1)
         assert oauth_uri_path == TEST_OIDC_AUTH_ENDPOINT, (
             "unexpected SSO URI " + oauth_uri_path
         )
+        return self.complete_oidc_auth(oauth_uri, cookies, user_info_dict)
+
+    def complete_oidc_auth(
+        self, oauth_uri: str, cookies: Mapping[str, str], user_info_dict: JsonDict,
+    ) -> FakeChannel:
+        """Mock out an OIDC authentication flow
+
+        Assumes that an OIDC auth has been initiated by one of initiate_sso_login or
+        initiate_sso_ui_auth; completes the OIDC bits of the flow by making a request to
+        Synapse's OIDC callback endpoint, intercepting the HTTP requests that will get
+        sent back to the OIDC provider.
+
+        Requires the OIDC callback resource to be mounted at the normal place.
+
+        Args:
+            oauth_uri: the OIDC URI returned by synapse's redirect endpoint (ie,
+               from initiate_sso_login or initiate_sso_ui_auth).
+            cookies: the cookies set by synapse's redirect endpoint, which will be
+               sent back to the callback endpoint.
+            user_info_dict: the remote userinfo that the OIDC provider should present.
+                Typically this should be '{"sub": "<remote user id>"}'.
+
+        Returns:
+            A FakeChannel containing the result of calling the OIDC callback endpoint.
+        """
+        _, oauth_uri_qs = oauth_uri.split("?", 1)
         params = urllib.parse.parse_qs(oauth_uri_qs)
         callback_uri = "%s?%s" % (
             urllib.parse.urlparse(params["redirect_uri"][0]).path,
@@ -456,9 +482,9 @@ class RestHelper:
         expected_requests = [
             # first we get a hit to the token endpoint, which we tell to return
             # a dummy OIDC access token
-            ("https://issuer.test/token", {"access_token": "TEST"}),
+            (TEST_OIDC_TOKEN_ENDPOINT, {"access_token": "TEST"}),
             # and then one to the user_info endpoint, which returns our remote user id.
-            ("https://issuer.test/userinfo", user_info_dict),
+            (TEST_OIDC_USERINFO_ENDPOINT, user_info_dict),
         ]
 
         async def mock_req(method: str, uri: str, data=None, headers=None):
@@ -542,25 +568,7 @@ class RestHelper:
         channel.extract_cookies(cookies)
 
         # parse the confirmation page to fish out the link.
-        class ConfirmationPageParser(HTMLParser):
-            def __init__(self):
-                super().__init__()
-
-                self.links = []  # type: List[str]
-
-            def handle_starttag(
-                self, tag: str, attrs: Iterable[Tuple[str, Optional[str]]]
-            ) -> None:
-                attr_dict = dict(attrs)
-                if tag == "a":
-                    href = attr_dict["href"]
-                    if href:
-                        self.links.append(href)
-
-            def error(_, message):
-                raise AssertionError(message)
-
-        p = ConfirmationPageParser()
+        p = TestHtmlParser()
         p.feed(channel.text_body)
         p.close()
         assert len(p.links) == 1, "not exactly one link in confirmation page"
@@ -570,6 +578,8 @@ class RestHelper:
 
 # an 'oidc_config' suitable for login_via_oidc.
 TEST_OIDC_AUTH_ENDPOINT = "https://issuer.test/auth"
+TEST_OIDC_TOKEN_ENDPOINT = "https://issuer.test/token"
+TEST_OIDC_USERINFO_ENDPOINT = "https://issuer.test/userinfo"
 TEST_OIDC_CONFIG = {
     "enabled": True,
     "discover": False,
@@ -578,7 +588,7 @@ TEST_OIDC_CONFIG = {
     "client_secret": "test-client-secret",
     "scopes": ["profile"],
     "authorization_endpoint": TEST_OIDC_AUTH_ENDPOINT,
-    "token_endpoint": "https://issuer.test/token",
-    "userinfo_endpoint": "https://issuer.test/userinfo",
+    "token_endpoint": TEST_OIDC_TOKEN_ENDPOINT,
+    "userinfo_endpoint": TEST_OIDC_USERINFO_ENDPOINT,
     "user_mapping_provider": {"config": {"localpart_template": "{{ user.sub }}"}},
 }