summary refs log tree commit diff
path: root/tests/rest/client/v1/utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/rest/client/v1/utils.py')
-rw-r--r--tests/rest/client/v1/utils.py268
1 files changed, 258 insertions, 10 deletions
diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py
index 737c38c396..b1333df82d 100644
--- a/tests/rest/client/v1/utils.py
+++ b/tests/rest/client/v1/utils.py
@@ -2,7 +2,7 @@
 # Copyright 2014-2016 OpenMarket Ltd
 # Copyright 2017 Vector Creations Ltd
 # Copyright 2018-2019 New Vector Ltd
-# Copyright 2019 The Matrix.org Foundation C.I.C.
+# Copyright 2019-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -17,8 +17,12 @@
 # limitations under the License.
 
 import json
+import re
 import time
-from typing import Any, Dict, Optional
+import urllib.parse
+from typing import Any, Dict, Mapping, MutableMapping, Optional
+
+from mock import patch
 
 import attr
 
@@ -26,8 +30,11 @@ from twisted.web.resource import Resource
 from twisted.web.server import Site
 
 from synapse.api.constants import Membership
+from synapse.types import JsonDict
 
-from tests.server import FakeSite, make_request
+from tests.server import FakeChannel, FakeSite, make_request
+from tests.test_utils import FakeResponse
+from tests.test_utils.html_parsers import TestHtmlParser
 
 
 @attr.s
@@ -75,7 +82,7 @@ class RestHelper:
         if tok:
             path = path + "?access_token=%s" % tok
 
-        _, channel = make_request(
+        channel = make_request(
             self.hs.get_reactor(),
             self.site,
             "POST",
@@ -151,7 +158,7 @@ class RestHelper:
         data = {"membership": membership}
         data.update(extra_data)
 
-        _, channel = make_request(
+        channel = make_request(
             self.hs.get_reactor(),
             self.site,
             "PUT",
@@ -186,7 +193,7 @@ class RestHelper:
         if tok:
             path = path + "?access_token=%s" % tok
 
-        _, channel = make_request(
+        channel = make_request(
             self.hs.get_reactor(),
             self.site,
             "PUT",
@@ -242,9 +249,7 @@ class RestHelper:
         if body is not None:
             content = json.dumps(body).encode("utf8")
 
-        _, channel = make_request(
-            self.hs.get_reactor(), self.site, method, path, content
-        )
+        channel = make_request(self.hs.get_reactor(), self.site, method, path, content)
 
         assert int(channel.result["code"]) == expect_code, (
             "Expected: %d, got: %d, resp: %r"
@@ -327,7 +332,7 @@ class RestHelper:
         """
         image_length = len(image_data)
         path = "/_matrix/media/r0/upload?filename=%s" % (filename,)
-        _, channel = make_request(
+        channel = make_request(
             self.hs.get_reactor(),
             FakeSite(resource),
             "POST",
@@ -344,3 +349,246 @@ class RestHelper:
         )
 
         return channel.json_body
+
+    def login_via_oidc(self, remote_user_id: str) -> JsonDict:
+        """Log in (as a new user) via OIDC
+
+        Returns the result of the final token login.
+
+        Requires that "oidc_config" in the homeserver config be set appropriately
+        (TEST_OIDC_CONFIG is a suitable example) - and by implication, needs a
+        "public_base_url".
+
+        Also requires the login servlet and the OIDC callback resource to be mounted at
+        the normal places.
+        """
+        client_redirect_url = "https://x"
+        channel = self.auth_via_oidc({"sub": remote_user_id}, client_redirect_url)
+
+        # expect a confirmation page
+        assert channel.code == 200, channel.result
+
+        # fish the matrix login token out of the body of the confirmation page
+        m = re.search(
+            'a href="%s.*loginToken=([^"]*)"' % (client_redirect_url,),
+            channel.text_body,
+        )
+        assert m, channel.text_body
+        login_token = m.group(1)
+
+        # finally, submit the matrix login token to the login API, which gives us our
+        # matrix access token and device id.
+        channel = make_request(
+            self.hs.get_reactor(),
+            self.site,
+            "POST",
+            "/login",
+            content={"type": "m.login.token", "token": login_token},
+        )
+        assert channel.code == 200
+        return channel.json_body
+
+    def auth_via_oidc(
+        self,
+        user_info_dict: JsonDict,
+        client_redirect_url: Optional[str] = None,
+        ui_auth_session_id: Optional[str] = None,
+    ) -> FakeChannel:
+        """Perform an OIDC authentication flow via a mock OIDC provider.
+
+        This can be used for either login or user-interactive auth.
+
+        Starts by making a request to the relevant synapse redirect endpoint, which is
+        expected to serve a 302 to the OIDC provider. We then make a request to the
+        OIDC callback endpoint, intercepting the HTTP requests that will get sent back
+        to the OIDC provider.
+
+        Requires that "oidc_config" in the homeserver config be set appropriately
+        (TEST_OIDC_CONFIG is a suitable example) - and by implication, needs a
+        "public_base_url".
+
+        Also requires the login servlet and the OIDC callback resource to be mounted at
+        the normal places.
+
+        Args:
+            user_info_dict: the remote userinfo that the OIDC provider should present.
+                Typically this should be '{"sub": "<remote user id>"}'.
+            client_redirect_url: for a login flow, the client redirect URL to pass to
+                the login redirect endpoint
+            ui_auth_session_id: if set, we will perform a UI Auth flow. The session id
+                of the UI auth.
+
+        Returns:
+            A FakeChannel containing the result of calling the OIDC callback endpoint.
+            Note that the response code may be a 200, 302 or 400 depending on how things
+            went.
+        """
+
+        cookies = {}
+
+        # if we're doing a ui auth, hit the ui auth redirect endpoint
+        if ui_auth_session_id:
+            # can't set the client redirect url for UI Auth
+            assert client_redirect_url is None
+            oauth_uri = self.initiate_sso_ui_auth(ui_auth_session_id, cookies)
+        else:
+            # otherwise, hit the login redirect endpoint
+            oauth_uri = self.initiate_sso_login(client_redirect_url, cookies)
+
+        # we now have a URI for the OIDC IdP, but we skip that and go straight
+        # back to synapse's OIDC callback resource. However, we do need the "state"
+        # param that synapse passes to the IdP via query params, as well as the cookie
+        # that synapse passes to the client.
+
+        oauth_uri_path, _ = oauth_uri.split("?", 1)
+        assert oauth_uri_path == TEST_OIDC_AUTH_ENDPOINT, (
+            "unexpected SSO URI " + oauth_uri_path
+        )
+        return self.complete_oidc_auth(oauth_uri, cookies, user_info_dict)
+
+    def complete_oidc_auth(
+        self, oauth_uri: str, cookies: Mapping[str, str], user_info_dict: JsonDict,
+    ) -> FakeChannel:
+        """Mock out an OIDC authentication flow
+
+        Assumes that an OIDC auth has been initiated by one of initiate_sso_login or
+        initiate_sso_ui_auth; completes the OIDC bits of the flow by making a request to
+        Synapse's OIDC callback endpoint, intercepting the HTTP requests that will get
+        sent back to the OIDC provider.
+
+        Requires the OIDC callback resource to be mounted at the normal place.
+
+        Args:
+            oauth_uri: the OIDC URI returned by synapse's redirect endpoint (ie,
+               from initiate_sso_login or initiate_sso_ui_auth).
+            cookies: the cookies set by synapse's redirect endpoint, which will be
+               sent back to the callback endpoint.
+            user_info_dict: the remote userinfo that the OIDC provider should present.
+                Typically this should be '{"sub": "<remote user id>"}'.
+
+        Returns:
+            A FakeChannel containing the result of calling the OIDC callback endpoint.
+        """
+        _, oauth_uri_qs = oauth_uri.split("?", 1)
+        params = urllib.parse.parse_qs(oauth_uri_qs)
+        callback_uri = "%s?%s" % (
+            urllib.parse.urlparse(params["redirect_uri"][0]).path,
+            urllib.parse.urlencode({"state": params["state"][0], "code": "TEST_CODE"}),
+        )
+
+        # before we hit the callback uri, stub out some methods in the http client so
+        # that we don't have to handle full HTTPS requests.
+        # (expected url, json response) pairs, in the order we expect them.
+        expected_requests = [
+            # first we get a hit to the token endpoint, which we tell to return
+            # a dummy OIDC access token
+            (TEST_OIDC_TOKEN_ENDPOINT, {"access_token": "TEST"}),
+            # and then one to the user_info endpoint, which returns our remote user id.
+            (TEST_OIDC_USERINFO_ENDPOINT, user_info_dict),
+        ]
+
+        async def mock_req(method: str, uri: str, data=None, headers=None):
+            (expected_uri, resp_obj) = expected_requests.pop(0)
+            assert uri == expected_uri
+            resp = FakeResponse(
+                code=200, phrase=b"OK", body=json.dumps(resp_obj).encode("utf-8"),
+            )
+            return resp
+
+        with patch.object(self.hs.get_proxied_http_client(), "request", mock_req):
+            # now hit the callback URI with the right params and a made-up code
+            channel = make_request(
+                self.hs.get_reactor(),
+                self.site,
+                "GET",
+                callback_uri,
+                custom_headers=[
+                    ("Cookie", "%s=%s" % (k, v)) for (k, v) in cookies.items()
+                ],
+            )
+        return channel
+
+    def initiate_sso_login(
+        self, client_redirect_url: Optional[str], cookies: MutableMapping[str, str]
+    ) -> str:
+        """Make a request to the login-via-sso redirect endpoint, and return the target
+
+        Assumes that exactly one SSO provider has been configured. Requires the login
+        servlet to be mounted.
+
+        Args:
+            client_redirect_url: the client redirect URL to pass to the login redirect
+                endpoint
+            cookies: any cookies returned will be added to this dict
+
+        Returns:
+            the URI that the client gets redirected to (ie, the SSO server)
+        """
+        params = {}
+        if client_redirect_url:
+            params["redirectUrl"] = client_redirect_url
+
+        # hit the redirect url (which will issue a cookie and state)
+        channel = make_request(
+            self.hs.get_reactor(),
+            self.site,
+            "GET",
+            "/_matrix/client/r0/login/sso/redirect?" + urllib.parse.urlencode(params),
+        )
+
+        assert channel.code == 302
+        channel.extract_cookies(cookies)
+        return channel.headers.getRawHeaders("Location")[0]
+
+    def initiate_sso_ui_auth(
+        self, ui_auth_session_id: str, cookies: MutableMapping[str, str]
+    ) -> str:
+        """Make a request to the ui-auth-via-sso endpoint, and return the target
+
+        Assumes that exactly one SSO provider has been configured. Requires the
+        AuthRestServlet to be mounted.
+
+        Args:
+            ui_auth_session_id: the session id of the UI auth
+            cookies: any cookies returned will be added to this dict
+
+        Returns:
+            the URI that the client gets linked to (ie, the SSO server)
+        """
+        sso_redirect_endpoint = (
+            "/_matrix/client/r0/auth/m.login.sso/fallback/web?"
+            + urllib.parse.urlencode({"session": ui_auth_session_id})
+        )
+        # hit the redirect url (which will issue a cookie and state)
+        channel = make_request(
+            self.hs.get_reactor(), self.site, "GET", sso_redirect_endpoint
+        )
+        # that should serve a confirmation page
+        assert channel.code == 200, channel.text_body
+        channel.extract_cookies(cookies)
+
+        # parse the confirmation page to fish out the link.
+        p = TestHtmlParser()
+        p.feed(channel.text_body)
+        p.close()
+        assert len(p.links) == 1, "not exactly one link in confirmation page"
+        oauth_uri = p.links[0]
+        return oauth_uri
+
+
+# an 'oidc_config' suitable for login_via_oidc.
+TEST_OIDC_AUTH_ENDPOINT = "https://issuer.test/auth"
+TEST_OIDC_TOKEN_ENDPOINT = "https://issuer.test/token"
+TEST_OIDC_USERINFO_ENDPOINT = "https://issuer.test/userinfo"
+TEST_OIDC_CONFIG = {
+    "enabled": True,
+    "discover": False,
+    "issuer": "https://issuer.test",
+    "client_id": "test-client-id",
+    "client_secret": "test-client-secret",
+    "scopes": ["profile"],
+    "authorization_endpoint": TEST_OIDC_AUTH_ENDPOINT,
+    "token_endpoint": TEST_OIDC_TOKEN_ENDPOINT,
+    "userinfo_endpoint": TEST_OIDC_USERINFO_ENDPOINT,
+    "user_mapping_provider": {"config": {"localpart_template": "{{ user.sub }}"}},
+}