summary refs log tree commit diff
path: root/tests/handlers/test_oidc.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/handlers/test_oidc.py')
-rw-r--r--tests/handlers/test_oidc.py268
1 files changed, 84 insertions, 184 deletions
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index f5df657814..b3dfa40d25 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -13,20 +13,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import json
-import re
-from typing import Dict
-from urllib.parse import parse_qs, urlencode, urlparse
+from typing import Optional
+from urllib.parse import parse_qs, urlparse
 
 from mock import ANY, Mock, patch
 
 import pymacaroons
 
-from twisted.web.resource import Resource
-
-from synapse.api.errors import RedirectException
 from synapse.handlers.sso import MappingException
-from synapse.rest.client.v1 import login
-from synapse.rest.synapse.client.pick_username import pick_username_resource
 from synapse.server import HomeServer
 from synapse.types import UserID
 
@@ -151,6 +145,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
         hs = self.setup_test_homeserver(proxied_http_client=self.http_client)
 
         self.handler = hs.get_oidc_handler()
+        self.provider = self.handler._providers["oidc"]
         sso_handler = hs.get_sso_handler()
         # Mock the render error method.
         self.render_error = Mock(return_value=None)
@@ -162,9 +157,10 @@ class OidcHandlerTestCase(HomeserverTestCase):
         return hs
 
     def metadata_edit(self, values):
-        return patch.dict(self.handler._provider_metadata, values)
+        return patch.dict(self.provider._provider_metadata, values)
 
     def assertRenderedError(self, error, error_description=None):
+        self.render_error.assert_called_once()
         args = self.render_error.call_args[0]
         self.assertEqual(args[1], error)
         if error_description is not None:
@@ -175,15 +171,15 @@ class OidcHandlerTestCase(HomeserverTestCase):
 
     def test_config(self):
         """Basic config correctly sets up the callback URL and client auth correctly."""
-        self.assertEqual(self.handler._callback_url, CALLBACK_URL)
-        self.assertEqual(self.handler._client_auth.client_id, CLIENT_ID)
-        self.assertEqual(self.handler._client_auth.client_secret, CLIENT_SECRET)
+        self.assertEqual(self.provider._callback_url, CALLBACK_URL)
+        self.assertEqual(self.provider._client_auth.client_id, CLIENT_ID)
+        self.assertEqual(self.provider._client_auth.client_secret, CLIENT_SECRET)
 
     @override_config({"oidc_config": {"discover": True}})
     def test_discovery(self):
         """The handler should discover the endpoints from OIDC discovery document."""
         # This would throw if some metadata were invalid
-        metadata = self.get_success(self.handler.load_metadata())
+        metadata = self.get_success(self.provider.load_metadata())
         self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
 
         self.assertEqual(metadata.issuer, ISSUER)
@@ -195,47 +191,47 @@ class OidcHandlerTestCase(HomeserverTestCase):
 
         # subsequent calls should be cached
         self.http_client.reset_mock()
-        self.get_success(self.handler.load_metadata())
+        self.get_success(self.provider.load_metadata())
         self.http_client.get_json.assert_not_called()
 
     @override_config({"oidc_config": COMMON_CONFIG})
     def test_no_discovery(self):
         """When discovery is disabled, it should not try to load from discovery document."""
-        self.get_success(self.handler.load_metadata())
+        self.get_success(self.provider.load_metadata())
         self.http_client.get_json.assert_not_called()
 
     @override_config({"oidc_config": COMMON_CONFIG})
     def test_load_jwks(self):
         """JWKS loading is done once (then cached) if used."""
-        jwks = self.get_success(self.handler.load_jwks())
+        jwks = self.get_success(self.provider.load_jwks())
         self.http_client.get_json.assert_called_once_with(JWKS_URI)
         self.assertEqual(jwks, {"keys": []})
 
         # subsequent calls should be cached…
         self.http_client.reset_mock()
-        self.get_success(self.handler.load_jwks())
+        self.get_success(self.provider.load_jwks())
         self.http_client.get_json.assert_not_called()
 
         # …unless forced
         self.http_client.reset_mock()
-        self.get_success(self.handler.load_jwks(force=True))
+        self.get_success(self.provider.load_jwks(force=True))
         self.http_client.get_json.assert_called_once_with(JWKS_URI)
 
         # Throw if the JWKS uri is missing
         with self.metadata_edit({"jwks_uri": None}):
-            self.get_failure(self.handler.load_jwks(force=True), RuntimeError)
+            self.get_failure(self.provider.load_jwks(force=True), RuntimeError)
 
         # Return empty key set if JWKS are not used
-        self.handler._scopes = []  # not asking the openid scope
+        self.provider._scopes = []  # not asking the openid scope
         self.http_client.get_json.reset_mock()
-        jwks = self.get_success(self.handler.load_jwks(force=True))
+        jwks = self.get_success(self.provider.load_jwks(force=True))
         self.http_client.get_json.assert_not_called()
         self.assertEqual(jwks, {"keys": []})
 
     @override_config({"oidc_config": COMMON_CONFIG})
     def test_validate_config(self):
         """Provider metadatas are extensively validated."""
-        h = self.handler
+        h = self.provider
 
         # Default test config does not throw
         h._validate_metadata()
@@ -314,13 +310,13 @@ class OidcHandlerTestCase(HomeserverTestCase):
         """Provider metadata validation can be disabled by config."""
         with self.metadata_edit({"issuer": "http://insecure"}):
             # This should not throw
-            self.handler._validate_metadata()
+            self.provider._validate_metadata()
 
     def test_redirect_request(self):
         """The redirect request has the right arguments & generates a valid session cookie."""
         req = Mock(spec=["addCookie"])
         url = self.get_success(
-            self.handler.handle_redirect_request(req, b"http://client/redirect")
+            self.provider.handle_redirect_request(req, b"http://client/redirect")
         )
         url = urlparse(url)
         auth_endpoint = urlparse(AUTHORIZATION_ENDPOINT)
@@ -349,9 +345,13 @@ class OidcHandlerTestCase(HomeserverTestCase):
         cookie = args[1]
 
         macaroon = pymacaroons.Macaroon.deserialize(cookie)
-        state = self.handler._get_value_from_macaroon(macaroon, "state")
-        nonce = self.handler._get_value_from_macaroon(macaroon, "nonce")
-        redirect = self.handler._get_value_from_macaroon(
+        state = self.handler._token_generator._get_value_from_macaroon(
+            macaroon, "state"
+        )
+        nonce = self.handler._token_generator._get_value_from_macaroon(
+            macaroon, "nonce"
+        )
+        redirect = self.handler._token_generator._get_value_from_macaroon(
             macaroon, "client_redirect_url"
         )
 
@@ -384,7 +384,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
 
         # ensure that we are correctly testing the fallback when "get_extra_attributes"
         # is not implemented.
-        mapping_provider = self.handler._user_mapping_provider
+        mapping_provider = self.provider._user_mapping_provider
         with self.assertRaises(AttributeError):
             _ = mapping_provider.get_extra_attributes
 
@@ -399,9 +399,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
             "username": username,
         }
         expected_user_id = "@%s:%s" % (username, self.hs.hostname)
-        self.handler._exchange_code = simple_async_mock(return_value=token)
-        self.handler._parse_id_token = simple_async_mock(return_value=userinfo)
-        self.handler._fetch_userinfo = simple_async_mock(return_value=userinfo)
+        self.provider._exchange_code = simple_async_mock(return_value=token)
+        self.provider._parse_id_token = simple_async_mock(return_value=userinfo)
+        self.provider._fetch_userinfo = simple_async_mock(return_value=userinfo)
         auth_handler = self.hs.get_auth_handler()
         auth_handler.complete_sso_login = simple_async_mock()
 
@@ -411,12 +411,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
         client_redirect_url = "http://client/redirect"
         user_agent = "Browser"
         ip_address = "10.0.0.1"
-        session = self.handler._generate_oidc_session_token(
-            state=state,
-            nonce=nonce,
-            client_redirect_url=client_redirect_url,
-            ui_auth_session_id=None,
-        )
+        session = self._generate_oidc_session_token(state, nonce, client_redirect_url)
         request = _build_callback_request(
             code, state, session, user_agent=user_agent, ip_address=ip_address
         )
@@ -426,14 +421,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
         auth_handler.complete_sso_login.assert_called_once_with(
             expected_user_id, request, client_redirect_url, None,
         )
-        self.handler._exchange_code.assert_called_once_with(code)
-        self.handler._parse_id_token.assert_called_once_with(token, nonce=nonce)
-        self.handler._fetch_userinfo.assert_not_called()
+        self.provider._exchange_code.assert_called_once_with(code)
+        self.provider._parse_id_token.assert_called_once_with(token, nonce=nonce)
+        self.provider._fetch_userinfo.assert_not_called()
         self.render_error.assert_not_called()
 
         # Handle mapping errors
         with patch.object(
-            self.handler,
+            self.provider,
             "_remote_id_from_userinfo",
             new=Mock(side_effect=MappingException()),
         ):
@@ -441,36 +436,36 @@ class OidcHandlerTestCase(HomeserverTestCase):
             self.assertRenderedError("mapping_error")
 
         # Handle ID token errors
-        self.handler._parse_id_token = simple_async_mock(raises=Exception())
+        self.provider._parse_id_token = simple_async_mock(raises=Exception())
         self.get_success(self.handler.handle_oidc_callback(request))
         self.assertRenderedError("invalid_token")
 
         auth_handler.complete_sso_login.reset_mock()
-        self.handler._exchange_code.reset_mock()
-        self.handler._parse_id_token.reset_mock()
-        self.handler._fetch_userinfo.reset_mock()
+        self.provider._exchange_code.reset_mock()
+        self.provider._parse_id_token.reset_mock()
+        self.provider._fetch_userinfo.reset_mock()
 
         # With userinfo fetching
-        self.handler._scopes = []  # do not ask the "openid" scope
+        self.provider._scopes = []  # do not ask the "openid" scope
         self.get_success(self.handler.handle_oidc_callback(request))
 
         auth_handler.complete_sso_login.assert_called_once_with(
             expected_user_id, request, client_redirect_url, None,
         )
-        self.handler._exchange_code.assert_called_once_with(code)
-        self.handler._parse_id_token.assert_not_called()
-        self.handler._fetch_userinfo.assert_called_once_with(token)
+        self.provider._exchange_code.assert_called_once_with(code)
+        self.provider._parse_id_token.assert_not_called()
+        self.provider._fetch_userinfo.assert_called_once_with(token)
         self.render_error.assert_not_called()
 
         # Handle userinfo fetching error
-        self.handler._fetch_userinfo = simple_async_mock(raises=Exception())
+        self.provider._fetch_userinfo = simple_async_mock(raises=Exception())
         self.get_success(self.handler.handle_oidc_callback(request))
         self.assertRenderedError("fetch_error")
 
         # Handle code exchange failure
         from synapse.handlers.oidc_handler import OidcError
 
-        self.handler._exchange_code = simple_async_mock(
+        self.provider._exchange_code = simple_async_mock(
             raises=OidcError("invalid_request")
         )
         self.get_success(self.handler.handle_oidc_callback(request))
@@ -500,11 +495,8 @@ class OidcHandlerTestCase(HomeserverTestCase):
         self.assertRenderedError("invalid_session")
 
         # Mismatching session
-        session = self.handler._generate_oidc_session_token(
-            state="state",
-            nonce="nonce",
-            client_redirect_url="http://client/redirect",
-            ui_auth_session_id=None,
+        session = self._generate_oidc_session_token(
+            state="state", nonce="nonce", client_redirect_url="http://client/redirect",
         )
         request.args = {}
         request.args[b"state"] = [b"mismatching state"]
@@ -528,7 +520,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
             return_value=FakeResponse(code=200, phrase=b"OK", body=token_json)
         )
         code = "code"
-        ret = self.get_success(self.handler._exchange_code(code))
+        ret = self.get_success(self.provider._exchange_code(code))
         kwargs = self.http_client.request.call_args[1]
 
         self.assertEqual(ret, token)
@@ -552,7 +544,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
         )
         from synapse.handlers.oidc_handler import OidcError
 
-        exc = self.get_failure(self.handler._exchange_code(code), OidcError)
+        exc = self.get_failure(self.provider._exchange_code(code), OidcError)
         self.assertEqual(exc.value.error, "foo")
         self.assertEqual(exc.value.error_description, "bar")
 
@@ -562,7 +554,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
                 code=500, phrase=b"Internal Server Error", body=b"Not JSON",
             )
         )
-        exc = self.get_failure(self.handler._exchange_code(code), OidcError)
+        exc = self.get_failure(self.provider._exchange_code(code), OidcError)
         self.assertEqual(exc.value.error, "server_error")
 
         # Internal server error with JSON body
@@ -574,14 +566,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
             )
         )
 
-        exc = self.get_failure(self.handler._exchange_code(code), OidcError)
+        exc = self.get_failure(self.provider._exchange_code(code), OidcError)
         self.assertEqual(exc.value.error, "internal_server_error")
 
         # 4xx error without "error" field
         self.http_client.request = simple_async_mock(
             return_value=FakeResponse(code=400, phrase=b"Bad request", body=b"{}",)
         )
-        exc = self.get_failure(self.handler._exchange_code(code), OidcError)
+        exc = self.get_failure(self.provider._exchange_code(code), OidcError)
         self.assertEqual(exc.value.error, "server_error")
 
         # 2xx error with "error" field
@@ -590,7 +582,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
                 code=200, phrase=b"OK", body=b'{"error": "some_error"}',
             )
         )
-        exc = self.get_failure(self.handler._exchange_code(code), OidcError)
+        exc = self.get_failure(self.provider._exchange_code(code), OidcError)
         self.assertEqual(exc.value.error, "some_error")
 
     @override_config(
@@ -616,18 +608,15 @@ class OidcHandlerTestCase(HomeserverTestCase):
             "username": "foo",
             "phone": "1234567",
         }
-        self.handler._exchange_code = simple_async_mock(return_value=token)
-        self.handler._parse_id_token = simple_async_mock(return_value=userinfo)
+        self.provider._exchange_code = simple_async_mock(return_value=token)
+        self.provider._parse_id_token = simple_async_mock(return_value=userinfo)
         auth_handler = self.hs.get_auth_handler()
         auth_handler.complete_sso_login = simple_async_mock()
 
         state = "state"
         client_redirect_url = "http://client/redirect"
-        session = self.handler._generate_oidc_session_token(
-            state=state,
-            nonce="nonce",
-            client_redirect_url=client_redirect_url,
-            ui_auth_session_id=None,
+        session = self._generate_oidc_session_token(
+            state=state, nonce="nonce", client_redirect_url=client_redirect_url,
         )
         request = _build_callback_request("code", state, session)
 
@@ -841,116 +830,25 @@ class OidcHandlerTestCase(HomeserverTestCase):
         self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
         self.assertRenderedError("mapping_error", "localpart is invalid: ")
 
+    def _generate_oidc_session_token(
+        self,
+        state: str,
+        nonce: str,
+        client_redirect_url: str,
+        ui_auth_session_id: Optional[str] = None,
+    ) -> str:
+        from synapse.handlers.oidc_handler import OidcSessionData
 
-class UsernamePickerTestCase(HomeserverTestCase):
-    if not HAS_OIDC:
-        skip = "requires OIDC"
-
-    servlets = [login.register_servlets]
-
-    def default_config(self):
-        config = super().default_config()
-        config["public_baseurl"] = BASE_URL
-        oidc_config = {
-            "enabled": True,
-            "client_id": CLIENT_ID,
-            "client_secret": CLIENT_SECRET,
-            "issuer": ISSUER,
-            "scopes": SCOPES,
-            "user_mapping_provider": {
-                "config": {"display_name_template": "{{ user.displayname }}"}
-            },
-        }
-
-        # Update this config with what's in the default config so that
-        # override_config works as expected.
-        oidc_config.update(config.get("oidc_config", {}))
-        config["oidc_config"] = oidc_config
-
-        # whitelist this client URI so we redirect straight to it rather than
-        # serving a confirmation page
-        config["sso"] = {"client_whitelist": ["https://whitelisted.client"]}
-        return config
-
-    def create_resource_dict(self) -> Dict[str, Resource]:
-        d = super().create_resource_dict()
-        d["/_synapse/client/pick_username"] = pick_username_resource(self.hs)
-        return d
-
-    def test_username_picker(self):
-        """Test the happy path of a username picker flow."""
-        client_redirect_url = "https://whitelisted.client"
-
-        # first of all, mock up an OIDC callback to the OidcHandler, which should
-        # raise a RedirectException
-        userinfo = {"sub": "tester", "displayname": "Jonny"}
-        f = self.get_failure(
-            _make_callback_with_userinfo(
-                self.hs, userinfo, client_redirect_url=client_redirect_url
+        return self.handler._token_generator.generate_oidc_session_token(
+            state=state,
+            session_data=OidcSessionData(
+                idp_id="oidc",
+                nonce=nonce,
+                client_redirect_url=client_redirect_url,
+                ui_auth_session_id=ui_auth_session_id,
             ),
-            RedirectException,
-        )
-
-        # check the Location and cookies returned by the RedirectException
-        self.assertEqual(f.value.location, b"/_synapse/client/pick_username")
-        cookieheader = f.value.cookies[0]
-        regex = re.compile(b"^username_mapping_session=([a-zA-Z]+);")
-        m = regex.search(cookieheader)
-        if not m:
-            self.fail("cookie header %s does not match %s" % (cookieheader, regex))
-
-        # introspect the sso handler a bit to check that the username mapping session
-        # looks ok.
-        session_id = m.group(1).decode("ascii")
-        username_mapping_sessions = self.hs.get_sso_handler()._username_mapping_sessions
-        self.assertIn(
-            session_id, username_mapping_sessions, "session id not found in map"
-        )
-        session = username_mapping_sessions[session_id]
-        self.assertEqual(session.remote_user_id, "tester")
-        self.assertEqual(session.display_name, "Jonny")
-        self.assertEqual(session.client_redirect_url, client_redirect_url)
-
-        # the expiry time should be about 15 minutes away
-        expected_expiry = self.clock.time_msec() + (15 * 60 * 1000)
-        self.assertApproximates(session.expiry_time_ms, expected_expiry, tolerance=1000)
-
-        # Now, submit a username to the username picker, which should serve a redirect
-        # back to the client
-        submit_path = f.value.location + b"/submit"
-        content = urlencode({b"username": b"bobby"}).encode("utf8")
-        chan = self.make_request(
-            "POST",
-            path=submit_path,
-            content=content,
-            content_is_form=True,
-            custom_headers=[
-                ("Cookie", cookieheader),
-                # old versions of twisted don't do form-parsing without a valid
-                # content-length header.
-                ("Content-Length", str(len(content))),
-            ],
-        )
-        self.assertEqual(chan.code, 302, chan.result)
-        location_headers = chan.headers.getRawHeaders("Location")
-        # ensure that the returned location starts with the requested redirect URL
-        self.assertEqual(
-            location_headers[0][: len(client_redirect_url)], client_redirect_url
         )
 
-        # fish the login token out of the returned redirect uri
-        parts = urlparse(location_headers[0])
-        query = parse_qs(parts.query)
-        login_token = query["loginToken"][0]
-
-        # finally, submit the matrix login token to the login API, which gives us our
-        # matrix access token, mxid, and device id.
-        chan = self.make_request(
-            "POST", "/login", content={"type": "m.login.token", "token": login_token},
-        )
-        self.assertEqual(chan.code, 200, chan.result)
-        self.assertEqual(chan.json_body["user_id"], "@bobby:test")
-
 
 async def _make_callback_with_userinfo(
     hs: HomeServer, userinfo: dict, client_redirect_url: str = "http://client/redirect"
@@ -965,17 +863,20 @@ async def _make_callback_with_userinfo(
         userinfo: the OIDC userinfo dict
         client_redirect_url: the URL to redirect to on success.
     """
+    from synapse.handlers.oidc_handler import OidcSessionData
+
     handler = hs.get_oidc_handler()
-    handler._exchange_code = simple_async_mock(return_value={})
-    handler._parse_id_token = simple_async_mock(return_value=userinfo)
-    handler._fetch_userinfo = simple_async_mock(return_value=userinfo)
+    provider = handler._providers["oidc"]
+    provider._exchange_code = simple_async_mock(return_value={})
+    provider._parse_id_token = simple_async_mock(return_value=userinfo)
+    provider._fetch_userinfo = simple_async_mock(return_value=userinfo)
 
     state = "state"
-    session = handler._generate_oidc_session_token(
+    session = handler._token_generator.generate_oidc_session_token(
         state=state,
-        nonce="nonce",
-        client_redirect_url=client_redirect_url,
-        ui_auth_session_id=None,
+        session_data=OidcSessionData(
+            idp_id="oidc", nonce="nonce", client_redirect_url=client_redirect_url,
+        ),
     )
     request = _build_callback_request("code", state, session)
 
@@ -1011,7 +912,7 @@ def _build_callback_request(
             "addCookie",
             "requestHeaders",
             "getClientIP",
-            "get_user_agent",
+            "getHeader",
         ]
     )
 
@@ -1020,5 +921,4 @@ def _build_callback_request(
     request.args[b"code"] = [code.encode("utf-8")]
     request.args[b"state"] = [state.encode("utf-8")]
     request.getClientIP.return_value = ip_address
-    request.get_user_agent.return_value = user_agent
     return request