summary refs log tree commit diff
path: root/tests/handlers/test_oidc.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/handlers/test_oidc.py')
-rw-r--r--tests/handlers/test_oidc.py152
1 files changed, 144 insertions, 8 deletions
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index 49a1842b5c..adddbd002f 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -396,6 +396,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
         self.assertEqual(params["client_id"], [CLIENT_ID])
         self.assertEqual(len(params["state"]), 1)
         self.assertEqual(len(params["nonce"]), 1)
+        self.assertNotIn("code_challenge", params)
 
         # Check what is in the cookies
         self.assertEqual(len(req.cookies), 2)  # two cookies
@@ -411,13 +412,118 @@ class OidcHandlerTestCase(HomeserverTestCase):
         macaroon = pymacaroons.Macaroon.deserialize(cookie)
         state = get_value_from_macaroon(macaroon, "state")
         nonce = get_value_from_macaroon(macaroon, "nonce")
+        code_verifier = get_value_from_macaroon(macaroon, "code_verifier")
         redirect = get_value_from_macaroon(macaroon, "client_redirect_url")
 
         self.assertEqual(params["state"], [state])
         self.assertEqual(params["nonce"], [nonce])
+        self.assertEqual(code_verifier, "")
         self.assertEqual(redirect, "http://client/redirect")
 
     @override_config({"oidc_config": DEFAULT_CONFIG})
+    def test_redirect_request_with_code_challenge(self) -> None:
+        """The redirect request has the right arguments & generates a valid session cookie."""
+        req = Mock(spec=["cookies"])
+        req.cookies = []
+
+        with self.metadata_edit({"code_challenge_methods_supported": ["S256"]}):
+            url = urlparse(
+                self.get_success(
+                    self.provider.handle_redirect_request(
+                        req, b"http://client/redirect"
+                    )
+                )
+            )
+
+        # Ensure the code_challenge param is added to the redirect.
+        params = parse_qs(url.query)
+        self.assertEqual(len(params["code_challenge"]), 1)
+
+        # Check what is in the cookies
+        self.assertEqual(len(req.cookies), 2)  # two cookies
+        cookie_header = req.cookies[0]
+
+        # The cookie name and path don't really matter, just that it has to be coherent
+        # between the callback & redirect handlers.
+        parts = [p.strip() for p in cookie_header.split(b";")]
+        self.assertIn(b"Path=/_synapse/client/oidc", parts)
+        name, cookie = parts[0].split(b"=")
+        self.assertEqual(name, b"oidc_session")
+
+        # Ensure the code_verifier is set in the cookie.
+        macaroon = pymacaroons.Macaroon.deserialize(cookie)
+        code_verifier = get_value_from_macaroon(macaroon, "code_verifier")
+        self.assertNotEqual(code_verifier, "")
+
+    @override_config({"oidc_config": {**DEFAULT_CONFIG, "pkce_method": "always"}})
+    def test_redirect_request_with_forced_code_challenge(self) -> None:
+        """The redirect request has the right arguments & generates a valid session cookie."""
+        req = Mock(spec=["cookies"])
+        req.cookies = []
+
+        url = urlparse(
+            self.get_success(
+                self.provider.handle_redirect_request(req, b"http://client/redirect")
+            )
+        )
+
+        # Ensure the code_challenge param is added to the redirect.
+        params = parse_qs(url.query)
+        self.assertEqual(len(params["code_challenge"]), 1)
+
+        # Check what is in the cookies
+        self.assertEqual(len(req.cookies), 2)  # two cookies
+        cookie_header = req.cookies[0]
+
+        # The cookie name and path don't really matter, just that it has to be coherent
+        # between the callback & redirect handlers.
+        parts = [p.strip() for p in cookie_header.split(b";")]
+        self.assertIn(b"Path=/_synapse/client/oidc", parts)
+        name, cookie = parts[0].split(b"=")
+        self.assertEqual(name, b"oidc_session")
+
+        # Ensure the code_verifier is set in the cookie.
+        macaroon = pymacaroons.Macaroon.deserialize(cookie)
+        code_verifier = get_value_from_macaroon(macaroon, "code_verifier")
+        self.assertNotEqual(code_verifier, "")
+
+    @override_config({"oidc_config": {**DEFAULT_CONFIG, "pkce_method": "never"}})
+    def test_redirect_request_with_disabled_code_challenge(self) -> None:
+        """The redirect request has the right arguments & generates a valid session cookie."""
+        req = Mock(spec=["cookies"])
+        req.cookies = []
+
+        # The metadata should state that PKCE is enabled.
+        with self.metadata_edit({"code_challenge_methods_supported": ["S256"]}):
+            url = urlparse(
+                self.get_success(
+                    self.provider.handle_redirect_request(
+                        req, b"http://client/redirect"
+                    )
+                )
+            )
+
+        # Ensure the code_challenge param is added to the redirect.
+        params = parse_qs(url.query)
+        self.assertNotIn("code_challenge", params)
+
+        # Check what is in the cookies
+        self.assertEqual(len(req.cookies), 2)  # two cookies
+        cookie_header = req.cookies[0]
+
+        # The cookie name and path don't really matter, just that it has to be coherent
+        # between the callback & redirect handlers.
+        parts = [p.strip() for p in cookie_header.split(b";")]
+        self.assertIn(b"Path=/_synapse/client/oidc", parts)
+        name, cookie = parts[0].split(b"=")
+        self.assertEqual(name, b"oidc_session")
+
+        # Ensure the code_verifier is blank in the cookie.
+        macaroon = pymacaroons.Macaroon.deserialize(cookie)
+        code_verifier = get_value_from_macaroon(macaroon, "code_verifier")
+        self.assertEqual(code_verifier, "")
+
+    @override_config({"oidc_config": DEFAULT_CONFIG})
     def test_callback_error(self) -> None:
         """Errors from the provider returned in the callback are displayed."""
         request = Mock(args={})
@@ -601,7 +707,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
             payload=token
         )
         code = "code"
-        ret = self.get_success(self.provider._exchange_code(code))
+        ret = self.get_success(self.provider._exchange_code(code, code_verifier=""))
         kwargs = self.fake_server.request.call_args[1]
 
         self.assertEqual(ret, token)
@@ -615,13 +721,34 @@ class OidcHandlerTestCase(HomeserverTestCase):
         self.assertEqual(args["client_secret"], [CLIENT_SECRET])
         self.assertEqual(args["redirect_uri"], [CALLBACK_URL])
 
+        # Test providing a code verifier.
+        code_verifier = "code_verifier"
+        ret = self.get_success(
+            self.provider._exchange_code(code, code_verifier=code_verifier)
+        )
+        kwargs = self.fake_server.request.call_args[1]
+
+        self.assertEqual(ret, token)
+        self.assertEqual(kwargs["method"], "POST")
+        self.assertEqual(kwargs["uri"], self.fake_server.token_endpoint)
+
+        args = parse_qs(kwargs["data"].decode("utf-8"))
+        self.assertEqual(args["grant_type"], ["authorization_code"])
+        self.assertEqual(args["code"], [code])
+        self.assertEqual(args["client_id"], [CLIENT_ID])
+        self.assertEqual(args["client_secret"], [CLIENT_SECRET])
+        self.assertEqual(args["redirect_uri"], [CALLBACK_URL])
+        self.assertEqual(args["code_verifier"], [code_verifier])
+
         # Test error handling
         self.fake_server.post_token_handler.return_value = FakeResponse.json(
             code=400, payload={"error": "foo", "error_description": "bar"}
         )
         from synapse.handlers.oidc import OidcError
 
-        exc = self.get_failure(self.provider._exchange_code(code), OidcError)
+        exc = self.get_failure(
+            self.provider._exchange_code(code, code_verifier=""), OidcError
+        )
         self.assertEqual(exc.value.error, "foo")
         self.assertEqual(exc.value.error_description, "bar")
 
@@ -629,7 +756,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
         self.fake_server.post_token_handler.return_value = FakeResponse(
             code=500, body=b"Not JSON"
         )
-        exc = self.get_failure(self.provider._exchange_code(code), OidcError)
+        exc = self.get_failure(
+            self.provider._exchange_code(code, code_verifier=""), OidcError
+        )
         self.assertEqual(exc.value.error, "server_error")
 
         # Internal server error with JSON body
@@ -637,21 +766,27 @@ class OidcHandlerTestCase(HomeserverTestCase):
             code=500, payload={"error": "internal_server_error"}
         )
 
-        exc = self.get_failure(self.provider._exchange_code(code), OidcError)
+        exc = self.get_failure(
+            self.provider._exchange_code(code, code_verifier=""), OidcError
+        )
         self.assertEqual(exc.value.error, "internal_server_error")
 
         # 4xx error without "error" field
         self.fake_server.post_token_handler.return_value = FakeResponse.json(
             code=400, payload={}
         )
-        exc = self.get_failure(self.provider._exchange_code(code), OidcError)
+        exc = self.get_failure(
+            self.provider._exchange_code(code, code_verifier=""), OidcError
+        )
         self.assertEqual(exc.value.error, "server_error")
 
         # 2xx error with "error" field
         self.fake_server.post_token_handler.return_value = FakeResponse.json(
             code=200, payload={"error": "some_error"}
         )
-        exc = self.get_failure(self.provider._exchange_code(code), OidcError)
+        exc = self.get_failure(
+            self.provider._exchange_code(code, code_verifier=""), OidcError
+        )
         self.assertEqual(exc.value.error, "some_error")
 
     @override_config(
@@ -688,7 +823,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
         # timestamps.
         self.reactor.advance(1000)
         start_time = self.reactor.seconds()
-        ret = self.get_success(self.provider._exchange_code(code))
+        ret = self.get_success(self.provider._exchange_code(code, code_verifier=""))
 
         self.assertEqual(ret, token)
 
@@ -739,7 +874,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
             payload=token
         )
         code = "code"
-        ret = self.get_success(self.provider._exchange_code(code))
+        ret = self.get_success(self.provider._exchange_code(code, code_verifier=""))
 
         self.assertEqual(ret, token)
 
@@ -1203,6 +1338,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
                 nonce=nonce,
                 client_redirect_url=client_redirect_url,
                 ui_auth_session_id=ui_auth_session_id,
+                code_verifier="",
             ),
         )