summary refs log tree commit diff
path: root/tests/handlers/test_oidc.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/handlers/test_oidc.py')
-rw-r--r--tests/handlers/test_oidc.py94
1 files changed, 50 insertions, 44 deletions
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index e8418b6638..014815db6e 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -13,14 +13,18 @@
 # limitations under the License.
 import json
 import os
+from typing import Any, Dict
 from unittest.mock import ANY, Mock, patch
 from urllib.parse import parse_qs, urlparse
 
 import pymacaroons
 
+from twisted.test.proto_helpers import MemoryReactor
+
 from synapse.handlers.sso import MappingException
 from synapse.server import HomeServer
-from synapse.types import UserID
+from synapse.types import JsonDict, UserID
+from synapse.util import Clock
 from synapse.util.macaroons import get_value_from_macaroon
 
 from tests.test_utils import FakeResponse, get_awaitable_result, simple_async_mock
@@ -98,7 +102,7 @@ class TestMappingProviderFailures(TestMappingProvider):
         }
 
 
-async def get_json(url):
+async def get_json(url: str) -> JsonDict:
     # Mock get_json calls to handle jwks & oidc discovery endpoints
     if url == WELL_KNOWN:
         # Minimal discovery document, as defined in OpenID.Discovery
@@ -116,6 +120,8 @@ async def get_json(url):
     elif url == JWKS_URI:
         return {"keys": []}
 
+    return {}
+
 
 def _key_file_path() -> str:
     """path to a file containing the private half of a test key"""
@@ -147,12 +153,12 @@ class OidcHandlerTestCase(HomeserverTestCase):
     if not HAS_OIDC:
         skip = "requires OIDC"
 
-    def default_config(self):
+    def default_config(self) -> Dict[str, Any]:
         config = super().default_config()
         config["public_baseurl"] = BASE_URL
         return config
 
-    def make_homeserver(self, reactor, clock):
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
         self.http_client = Mock(spec=["get_json"])
         self.http_client.get_json.side_effect = get_json
         self.http_client.user_agent = b"Synapse Test"
@@ -164,7 +170,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
         sso_handler = hs.get_sso_handler()
         # Mock the render error method.
         self.render_error = Mock(return_value=None)
-        sso_handler.render_error = self.render_error
+        sso_handler.render_error = self.render_error  # type: ignore[assignment]
 
         # Reduce the number of attempts when generating MXIDs.
         sso_handler._MAP_USERNAME_RETRIES = 3
@@ -193,14 +199,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
         return args
 
     @override_config({"oidc_config": DEFAULT_CONFIG})
-    def test_config(self):
+    def test_config(self) -> None:
         """Basic config correctly sets up the callback URL and client auth correctly."""
         self.assertEqual(self.provider._callback_url, CALLBACK_URL)
         self.assertEqual(self.provider._client_auth.client_id, CLIENT_ID)
         self.assertEqual(self.provider._client_auth.client_secret, CLIENT_SECRET)
 
     @override_config({"oidc_config": {**DEFAULT_CONFIG, "discover": True}})
-    def test_discovery(self):
+    def test_discovery(self) -> None:
         """The handler should discover the endpoints from OIDC discovery document."""
         # This would throw if some metadata were invalid
         metadata = self.get_success(self.provider.load_metadata())
@@ -219,13 +225,13 @@ class OidcHandlerTestCase(HomeserverTestCase):
         self.http_client.get_json.assert_not_called()
 
     @override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG})
-    def test_no_discovery(self):
+    def test_no_discovery(self) -> None:
         """When discovery is disabled, it should not try to load from discovery document."""
         self.get_success(self.provider.load_metadata())
         self.http_client.get_json.assert_not_called()
 
     @override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG})
-    def test_load_jwks(self):
+    def test_load_jwks(self) -> None:
         """JWKS loading is done once (then cached) if used."""
         jwks = self.get_success(self.provider.load_jwks())
         self.http_client.get_json.assert_called_once_with(JWKS_URI)
@@ -253,7 +259,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
             self.get_failure(self.provider.load_jwks(force=True), RuntimeError)
 
     @override_config({"oidc_config": DEFAULT_CONFIG})
-    def test_validate_config(self):
+    def test_validate_config(self) -> None:
         """Provider metadatas are extensively validated."""
         h = self.provider
 
@@ -336,14 +342,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
             force_load_metadata()
 
     @override_config({"oidc_config": {**DEFAULT_CONFIG, "skip_verification": True}})
-    def test_skip_verification(self):
+    def test_skip_verification(self) -> None:
         """Provider metadata validation can be disabled by config."""
         with self.metadata_edit({"issuer": "http://insecure"}):
             # This should not throw
             get_awaitable_result(self.provider.load_metadata())
 
     @override_config({"oidc_config": DEFAULT_CONFIG})
-    def test_redirect_request(self):
+    def test_redirect_request(self) -> None:
         """The redirect request has the right arguments & generates a valid session cookie."""
         req = Mock(spec=["cookies"])
         req.cookies = []
@@ -387,7 +393,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
         self.assertEqual(redirect, "http://client/redirect")
 
     @override_config({"oidc_config": DEFAULT_CONFIG})
-    def test_callback_error(self):
+    def test_callback_error(self) -> None:
         """Errors from the provider returned in the callback are displayed."""
         request = Mock(args={})
         request.args[b"error"] = [b"invalid_client"]
@@ -399,7 +405,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
         self.assertRenderedError("invalid_client", "some description")
 
     @override_config({"oidc_config": DEFAULT_CONFIG})
-    def test_callback(self):
+    def test_callback(self) -> None:
         """Code callback works and display errors if something went wrong.
 
         A lot of scenarios are tested here:
@@ -428,9 +434,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
             "username": username,
         }
         expected_user_id = "@%s:%s" % (username, self.hs.hostname)
-        self.provider._exchange_code = simple_async_mock(return_value=token)
-        self.provider._parse_id_token = simple_async_mock(return_value=userinfo)
-        self.provider._fetch_userinfo = simple_async_mock(return_value=userinfo)
+        self.provider._exchange_code = simple_async_mock(return_value=token)  # type: ignore[assignment]
+        self.provider._parse_id_token = simple_async_mock(return_value=userinfo)  # type: ignore[assignment]
+        self.provider._fetch_userinfo = simple_async_mock(return_value=userinfo)  # type: ignore[assignment]
         auth_handler = self.hs.get_auth_handler()
         auth_handler.complete_sso_login = simple_async_mock()
 
@@ -468,7 +474,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
             self.assertRenderedError("mapping_error")
 
         # Handle ID token errors
-        self.provider._parse_id_token = simple_async_mock(raises=Exception())
+        self.provider._parse_id_token = simple_async_mock(raises=Exception())  # type: ignore[assignment]
         self.get_success(self.handler.handle_oidc_callback(request))
         self.assertRenderedError("invalid_token")
 
@@ -483,7 +489,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
             "type": "bearer",
             "access_token": "access_token",
         }
-        self.provider._exchange_code = simple_async_mock(return_value=token)
+        self.provider._exchange_code = simple_async_mock(return_value=token)  # type: ignore[assignment]
         self.get_success(self.handler.handle_oidc_callback(request))
 
         auth_handler.complete_sso_login.assert_called_once_with(
@@ -510,8 +516,8 @@ class OidcHandlerTestCase(HomeserverTestCase):
         id_token = {
             "sid": "abcdefgh",
         }
-        self.provider._parse_id_token = simple_async_mock(return_value=id_token)
-        self.provider._exchange_code = simple_async_mock(return_value=token)
+        self.provider._parse_id_token = simple_async_mock(return_value=id_token)  # type: ignore[assignment]
+        self.provider._exchange_code = simple_async_mock(return_value=token)  # type: ignore[assignment]
         auth_handler.complete_sso_login.reset_mock()
         self.provider._fetch_userinfo.reset_mock()
         self.get_success(self.handler.handle_oidc_callback(request))
@@ -531,21 +537,21 @@ class OidcHandlerTestCase(HomeserverTestCase):
         self.render_error.assert_not_called()
 
         # Handle userinfo fetching error
-        self.provider._fetch_userinfo = simple_async_mock(raises=Exception())
+        self.provider._fetch_userinfo = simple_async_mock(raises=Exception())  # type: ignore[assignment]
         self.get_success(self.handler.handle_oidc_callback(request))
         self.assertRenderedError("fetch_error")
 
         # Handle code exchange failure
         from synapse.handlers.oidc import OidcError
 
-        self.provider._exchange_code = simple_async_mock(
+        self.provider._exchange_code = simple_async_mock(  # type: ignore[assignment]
             raises=OidcError("invalid_request")
         )
         self.get_success(self.handler.handle_oidc_callback(request))
         self.assertRenderedError("invalid_request")
 
     @override_config({"oidc_config": DEFAULT_CONFIG})
-    def test_callback_session(self):
+    def test_callback_session(self) -> None:
         """The callback verifies the session presence and validity"""
         request = Mock(spec=["args", "getCookie", "cookies"])
 
@@ -590,7 +596,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
     @override_config(
         {"oidc_config": {**DEFAULT_CONFIG, "client_auth_method": "client_secret_post"}}
     )
-    def test_exchange_code(self):
+    def test_exchange_code(self) -> None:
         """Code exchange behaves correctly and handles various error scenarios."""
         token = {"type": "bearer"}
         token_json = json.dumps(token).encode("utf-8")
@@ -686,7 +692,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
             }
         }
     )
-    def test_exchange_code_jwt_key(self):
+    def test_exchange_code_jwt_key(self) -> None:
         """Test that code exchange works with a JWK client secret."""
         from authlib.jose import jwt
 
@@ -741,7 +747,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
             }
         }
     )
-    def test_exchange_code_no_auth(self):
+    def test_exchange_code_no_auth(self) -> None:
         """Test that code exchange works with no client secret."""
         token = {"type": "bearer"}
         self.http_client.request = simple_async_mock(
@@ -776,7 +782,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
             }
         }
     )
-    def test_extra_attributes(self):
+    def test_extra_attributes(self) -> None:
         """
         Login while using a mapping provider that implements get_extra_attributes.
         """
@@ -790,8 +796,8 @@ class OidcHandlerTestCase(HomeserverTestCase):
             "username": "foo",
             "phone": "1234567",
         }
-        self.provider._exchange_code = simple_async_mock(return_value=token)
-        self.provider._parse_id_token = simple_async_mock(return_value=userinfo)
+        self.provider._exchange_code = simple_async_mock(return_value=token)  # type: ignore[assignment]
+        self.provider._parse_id_token = simple_async_mock(return_value=userinfo)  # type: ignore[assignment]
         auth_handler = self.hs.get_auth_handler()
         auth_handler.complete_sso_login = simple_async_mock()
 
@@ -817,12 +823,12 @@ class OidcHandlerTestCase(HomeserverTestCase):
         )
 
     @override_config({"oidc_config": DEFAULT_CONFIG})
-    def test_map_userinfo_to_user(self):
+    def test_map_userinfo_to_user(self) -> None:
         """Ensure that mapping the userinfo returned from a provider to an MXID works properly."""
         auth_handler = self.hs.get_auth_handler()
         auth_handler.complete_sso_login = simple_async_mock()
 
-        userinfo = {
+        userinfo: dict = {
             "sub": "test_user",
             "username": "test_user",
         }
@@ -870,7 +876,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
         )
 
     @override_config({"oidc_config": {**DEFAULT_CONFIG, "allow_existing_users": True}})
-    def test_map_userinfo_to_existing_user(self):
+    def test_map_userinfo_to_existing_user(self) -> None:
         """Existing users can log in with OpenID Connect when allow_existing_users is True."""
         store = self.hs.get_datastores().main
         user = UserID.from_string("@test_user:test")
@@ -974,7 +980,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
         )
 
     @override_config({"oidc_config": DEFAULT_CONFIG})
-    def test_map_userinfo_to_invalid_localpart(self):
+    def test_map_userinfo_to_invalid_localpart(self) -> None:
         """If the mapping provider generates an invalid localpart it should be rejected."""
         self.get_success(
             _make_callback_with_userinfo(self.hs, {"sub": "test2", "username": "föö"})
@@ -991,7 +997,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
             }
         }
     )
-    def test_map_userinfo_to_user_retries(self):
+    def test_map_userinfo_to_user_retries(self) -> None:
         """The mapping provider can retry generating an MXID if the MXID is already in use."""
         auth_handler = self.hs.get_auth_handler()
         auth_handler.complete_sso_login = simple_async_mock()
@@ -1039,7 +1045,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
         )
 
     @override_config({"oidc_config": DEFAULT_CONFIG})
-    def test_empty_localpart(self):
+    def test_empty_localpart(self) -> None:
         """Attempts to map onto an empty localpart should be rejected."""
         userinfo = {
             "sub": "tester",
@@ -1058,7 +1064,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
             }
         }
     )
-    def test_null_localpart(self):
+    def test_null_localpart(self) -> None:
         """Mapping onto a null localpart via an empty OIDC attribute should be rejected"""
         userinfo = {
             "sub": "tester",
@@ -1075,7 +1081,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
             }
         }
     )
-    def test_attribute_requirements(self):
+    def test_attribute_requirements(self) -> None:
         """The required attributes must be met from the OIDC userinfo response."""
         auth_handler = self.hs.get_auth_handler()
         auth_handler.complete_sso_login = simple_async_mock()
@@ -1115,7 +1121,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
             }
         }
     )
-    def test_attribute_requirements_contains(self):
+    def test_attribute_requirements_contains(self) -> None:
         """Test that auth succeeds if userinfo attribute CONTAINS required value"""
         auth_handler = self.hs.get_auth_handler()
         auth_handler.complete_sso_login = simple_async_mock()
@@ -1146,7 +1152,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
             }
         }
     )
-    def test_attribute_requirements_mismatch(self):
+    def test_attribute_requirements_mismatch(self) -> None:
         """
         Test that auth fails if attributes exist but don't match,
         or are non-string values.
@@ -1154,7 +1160,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
         auth_handler = self.hs.get_auth_handler()
         auth_handler.complete_sso_login = simple_async_mock()
         # userinfo with "test": "not_foobar" attribute should fail
-        userinfo = {
+        userinfo: dict = {
             "sub": "tester",
             "username": "tester",
             "test": "not_foobar",
@@ -1248,9 +1254,9 @@ async def _make_callback_with_userinfo(
 
     handler = hs.get_oidc_handler()
     provider = handler._providers["oidc"]
-    provider._exchange_code = simple_async_mock(return_value={"id_token": ""})
-    provider._parse_id_token = simple_async_mock(return_value=userinfo)
-    provider._fetch_userinfo = simple_async_mock(return_value=userinfo)
+    provider._exchange_code = simple_async_mock(return_value={"id_token": ""})  # type: ignore[assignment]
+    provider._parse_id_token = simple_async_mock(return_value=userinfo)  # type: ignore[assignment]
+    provider._fetch_userinfo = simple_async_mock(return_value=userinfo)  # type: ignore[assignment]
 
     state = "state"
     session = handler._token_generator.generate_oidc_session_token(