summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/rest/client/v1/utils.py196
-rw-r--r--tests/rest/client/v2_alpha/test_auth.py40
-rw-r--r--tests/server.py32
3 files changed, 226 insertions, 42 deletions
diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py
index 81b7f84360..85d1709ead 100644
--- a/tests/rest/client/v1/utils.py
+++ b/tests/rest/client/v1/utils.py
@@ -2,7 +2,7 @@
 # Copyright 2014-2016 OpenMarket Ltd
 # Copyright 2017 Vector Creations Ltd
 # Copyright 2018-2019 New Vector Ltd
-# Copyright 2019-2020 The Matrix.org Foundation C.I.C.
+# Copyright 2019-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -20,7 +20,8 @@ import json
 import re
 import time
 import urllib.parse
-from typing import Any, Dict, Optional
+from html.parser import HTMLParser
+from typing import Any, Dict, Iterable, List, MutableMapping, Optional, Tuple
 
 from mock import patch
 
@@ -32,7 +33,7 @@ from twisted.web.server import Site
 from synapse.api.constants import Membership
 from synapse.types import JsonDict
 
-from tests.server import FakeSite, make_request
+from tests.server import FakeChannel, FakeSite, make_request
 from tests.test_utils import FakeResponse
 
 
@@ -362,34 +363,94 @@ class RestHelper:
         the normal places.
         """
         client_redirect_url = "https://x"
+        channel = self.auth_via_oidc(remote_user_id, client_redirect_url)
 
-        # first hit the redirect url (which will issue a cookie and state)
+        # expect a confirmation page
+        assert channel.code == 200
+
+        # fish the matrix login token out of the body of the confirmation page
+        m = re.search(
+            'a href="%s.*loginToken=([^"]*)"' % (client_redirect_url,),
+            channel.text_body,
+        )
+        assert m, channel.text_body
+        login_token = m.group(1)
+
+        # finally, submit the matrix login token to the login API, which gives us our
+        # matrix access token and device id.
         channel = make_request(
             self.hs.get_reactor(),
             self.site,
-            "GET",
-            "/login/sso/redirect?redirectUrl=" + client_redirect_url,
+            "POST",
+            "/login",
+            content={"type": "m.login.token", "token": login_token},
         )
-        # that will redirect to the OIDC IdP, but we skip that and go straight
+        assert channel.code == 200
+        return channel.json_body
+
+    def auth_via_oidc(
+        self,
+        remote_user_id: str,
+        client_redirect_url: Optional[str] = None,
+        ui_auth_session_id: Optional[str] = None,
+    ) -> FakeChannel:
+        """Perform an OIDC authentication flow via a mock OIDC provider.
+
+        This can be used for either login or user-interactive auth.
+
+        Starts by making a request to the relevant synapse redirect endpoint, which is
+        expected to serve a 302 to the OIDC provider. We then make a request to the
+        OIDC callback endpoint, intercepting the HTTP requests that will get sent back
+        to the OIDC provider.
+
+        Requires that "oidc_config" in the homeserver config be set appropriately
+        (TEST_OIDC_CONFIG is a suitable example) - and by implication, needs a
+        "public_base_url".
+
+        Also requires the login servlet and the OIDC callback resource to be mounted at
+        the normal places.
+
+        Args:
+            remote_user_id: the remote id that the OIDC provider should present
+            client_redirect_url: for a login flow, the client redirect URL to pass to
+                the login redirect endpoint
+            ui_auth_session_id: if set, we will perform a UI Auth flow. The session id
+                of the UI auth.
+
+        Returns:
+            A FakeChannel containing the result of calling the OIDC callback endpoint.
+            Note that the response code may be a 200, 302 or 400 depending on how things
+            went.
+        """
+
+        cookies = {}
+
+        # if we're doing a ui auth, hit the ui auth redirect endpoint
+        if ui_auth_session_id:
+            # can't set the client redirect url for UI Auth
+            assert client_redirect_url is None
+            oauth_uri = self.initiate_sso_ui_auth(ui_auth_session_id, cookies)
+        else:
+            # otherwise, hit the login redirect endpoint
+            oauth_uri = self.initiate_sso_login(client_redirect_url, cookies)
+
+        # we now have a URI for the OIDC IdP, but we skip that and go straight
         # back to synapse's OIDC callback resource. However, we do need the "state"
-        # param that synapse passes to the IdP via query params, and the cookie that
-        # synapse passes to the client.
-        assert channel.code == 302
-        oauth_uri = channel.headers.getRawHeaders("Location")[0]
-        params = urllib.parse.parse_qs(urllib.parse.urlparse(oauth_uri).query)
-        redirect_uri = "%s?%s" % (
+        # param that synapse passes to the IdP via query params, as well as the cookie
+        # that synapse passes to the client.
+
+        oauth_uri_path, oauth_uri_qs = oauth_uri.split("?", 1)
+        assert oauth_uri_path == TEST_OIDC_AUTH_ENDPOINT, (
+            "unexpected SSO URI " + oauth_uri_path
+        )
+        params = urllib.parse.parse_qs(oauth_uri_qs)
+        callback_uri = "%s?%s" % (
             urllib.parse.urlparse(params["redirect_uri"][0]).path,
             urllib.parse.urlencode({"state": params["state"][0], "code": "TEST_CODE"}),
         )
-        cookies = {}
-        for h in channel.headers.getRawHeaders("Set-Cookie"):
-            parts = h.split(";")
-            k, v = parts[0].split("=", maxsplit=1)
-            cookies[k] = v
 
         # before we hit the callback uri, stub out some methods in the http client so
         # that we don't have to handle full HTTPS requests.
-
         # (expected url, json response) pairs, in the order we expect them.
         expected_requests = [
             # first we get a hit to the token endpoint, which we tell to return
@@ -413,34 +474,97 @@ class RestHelper:
                 self.hs.get_reactor(),
                 self.site,
                 "GET",
-                redirect_uri,
+                callback_uri,
                 custom_headers=[
                     ("Cookie", "%s=%s" % (k, v)) for (k, v) in cookies.items()
                 ],
             )
+        return channel
 
-        # expect a confirmation page
-        assert channel.code == 200
+    def initiate_sso_login(
+        self, client_redirect_url: Optional[str], cookies: MutableMapping[str, str]
+    ) -> str:
+        """Make a request to the login-via-sso redirect endpoint, and return the target
 
-        # fish the matrix login token out of the body of the confirmation page
-        m = re.search(
-            'a href="%s.*loginToken=([^"]*)"' % (client_redirect_url,),
-            channel.result["body"].decode("utf-8"),
-        )
-        assert m
-        login_token = m.group(1)
+        Assumes that exactly one SSO provider has been configured. Requires the login
+        servlet to be mounted.
 
-        # finally, submit the matrix login token to the login API, which gives us our
-        # matrix access token and device id.
+        Args:
+            client_redirect_url: the client redirect URL to pass to the login redirect
+                endpoint
+            cookies: any cookies returned will be added to this dict
+
+        Returns:
+            the URI that the client gets redirected to (ie, the SSO server)
+        """
+        params = {}
+        if client_redirect_url:
+            params["redirectUrl"] = client_redirect_url
+
+        # hit the redirect url (which will issue a cookie and state)
         channel = make_request(
             self.hs.get_reactor(),
             self.site,
-            "POST",
-            "/login",
-            content={"type": "m.login.token", "token": login_token},
+            "GET",
+            "/_matrix/client/r0/login/sso/redirect?" + urllib.parse.urlencode(params),
         )
-        assert channel.code == 200
-        return channel.json_body
+
+        assert channel.code == 302
+        channel.extract_cookies(cookies)
+        return channel.headers.getRawHeaders("Location")[0]
+
+    def initiate_sso_ui_auth(
+        self, ui_auth_session_id: str, cookies: MutableMapping[str, str]
+    ) -> str:
+        """Make a request to the ui-auth-via-sso endpoint, and return the target
+
+        Assumes that exactly one SSO provider has been configured. Requires the
+        AuthRestServlet to be mounted.
+
+        Args:
+            ui_auth_session_id: the session id of the UI auth
+            cookies: any cookies returned will be added to this dict
+
+        Returns:
+            the URI that the client gets linked to (ie, the SSO server)
+        """
+        sso_redirect_endpoint = (
+            "/_matrix/client/r0/auth/m.login.sso/fallback/web?"
+            + urllib.parse.urlencode({"session": ui_auth_session_id})
+        )
+        # hit the redirect url (which will issue a cookie and state)
+        channel = make_request(
+            self.hs.get_reactor(), self.site, "GET", sso_redirect_endpoint
+        )
+        # that should serve a confirmation page
+        assert channel.code == 200, channel.text_body
+        channel.extract_cookies(cookies)
+
+        # parse the confirmation page to fish out the link.
+        class ConfirmationPageParser(HTMLParser):
+            def __init__(self):
+                super().__init__()
+
+                self.links = []  # type: List[str]
+
+            def handle_starttag(
+                self, tag: str, attrs: Iterable[Tuple[str, Optional[str]]]
+            ) -> None:
+                attr_dict = dict(attrs)
+                if tag == "a":
+                    href = attr_dict["href"]
+                    if href:
+                        self.links.append(href)
+
+            def error(_, message):
+                raise AssertionError(message)
+
+        p = ConfirmationPageParser()
+        p.feed(channel.text_body)
+        p.close()
+        assert len(p.links) == 1, "not exactly one link in confirmation page"
+        oauth_uri = p.links[0]
+        return oauth_uri
 
 
 # an 'oidc_config' suitable for login_via_oidc.
diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/v2_alpha/test_auth.py
index bb91e0c331..5f6ca23b06 100644
--- a/tests/rest/client/v2_alpha/test_auth.py
+++ b/tests/rest/client/v2_alpha/test_auth.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2018 New Vector
+# Copyright 2020-2021 The Matrix.org Foundation C.I.C
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,7 +13,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 from typing import Union
 
 from twisted.internet.defer import succeed
@@ -386,6 +386,44 @@ class UIAuthTests(unittest.HomeserverTestCase):
 
     @skip_unless(HAS_OIDC, "requires OIDC")
     @override_config({"oidc_config": TEST_OIDC_CONFIG})
+    def test_ui_auth_via_sso(self):
+        """Test a successful UI Auth flow via SSO
+
+        This includes:
+          * hitting the UIA SSO redirect endpoint
+          * checking it serves a confirmation page which links to the OIDC provider
+          * calling back to the synapse oidc callback
+          * checking that the original operation succeeds
+        """
+
+        # log the user in
+        remote_user_id = UserID.from_string(self.user).localpart
+        login_resp = self.helper.login_via_oidc(remote_user_id)
+        self.assertEqual(login_resp["user_id"], self.user)
+
+        # initiate a UI Auth process by attempting to delete the device
+        channel = self.delete_device(self.user_tok, self.device_id, 401)
+
+        # check that SSO is offered
+        flows = channel.json_body["flows"]
+        self.assertIn({"stages": ["m.login.sso"]}, flows)
+
+        # run the UIA-via-SSO flow
+        session_id = channel.json_body["session"]
+        channel = self.helper.auth_via_oidc(
+            remote_user_id=remote_user_id, ui_auth_session_id=session_id
+        )
+
+        # that should serve a confirmation page
+        self.assertEqual(channel.code, 200, channel.result)
+
+        # and now the delete request should succeed.
+        self.delete_device(
+            self.user_tok, self.device_id, 200, body={"auth": {"session": session_id}},
+        )
+
+    @skip_unless(HAS_OIDC, "requires OIDC")
+    @override_config({"oidc_config": TEST_OIDC_CONFIG})
     def test_does_not_offer_password_for_sso_user(self):
         login_resp = self.helper.login_via_oidc("username")
         user_tok = login_resp["access_token"]
diff --git a/tests/server.py b/tests/server.py
index 7d1ad362c4..5a1b66270f 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -2,7 +2,7 @@ import json
 import logging
 from collections import deque
 from io import SEEK_END, BytesIO
-from typing import Callable, Iterable, Optional, Tuple, Union
+from typing import Callable, Iterable, MutableMapping, Optional, Tuple, Union
 
 import attr
 from typing_extensions import Deque
@@ -51,9 +51,21 @@ class FakeChannel:
 
     @property
     def json_body(self):
-        if not self.result:
-            raise Exception("No result yet.")
-        return json.loads(self.result["body"].decode("utf8"))
+        return json.loads(self.text_body)
+
+    @property
+    def text_body(self) -> str:
+        """The body of the result, utf-8-decoded.
+
+        Raises an exception if the request has not yet completed.
+        """
+        if not self.is_finished:
+            raise Exception("Request not yet completed")
+        return self.result["body"].decode("utf8")
+
+    def is_finished(self) -> bool:
+        """check if the response has been completely received"""
+        return self.result.get("done", False)
 
     @property
     def code(self):
@@ -124,7 +136,7 @@ class FakeChannel:
         self._reactor.run()
         x = 0
 
-        while not self.result.get("done"):
+        while not self.is_finished():
             # If there's a producer, tell it to resume producing so we get content
             if self._producer:
                 self._producer.resumeProducing()
@@ -136,6 +148,16 @@ class FakeChannel:
 
             self._reactor.advance(0.1)
 
+    def extract_cookies(self, cookies: MutableMapping[str, str]) -> None:
+        """Process the contents of any Set-Cookie headers in the response
+
+        Any cookines found are added to the given dict
+        """
+        for h in self.headers.getRawHeaders("Set-Cookie"):
+            parts = h.split(";")
+            k, v = parts[0].split("=", maxsplit=1)
+            cookies[k] = v
+
 
 class FakeSite:
     """