diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index e8418b6638..014815db6e 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -13,14 +13,18 @@
# limitations under the License.
import json
import os
+from typing import Any, Dict
from unittest.mock import ANY, Mock, patch
from urllib.parse import parse_qs, urlparse
import pymacaroons
+from twisted.test.proto_helpers import MemoryReactor
+
from synapse.handlers.sso import MappingException
from synapse.server import HomeServer
-from synapse.types import UserID
+from synapse.types import JsonDict, UserID
+from synapse.util import Clock
from synapse.util.macaroons import get_value_from_macaroon
from tests.test_utils import FakeResponse, get_awaitable_result, simple_async_mock
@@ -98,7 +102,7 @@ class TestMappingProviderFailures(TestMappingProvider):
}
-async def get_json(url):
+async def get_json(url: str) -> JsonDict:
# Mock get_json calls to handle jwks & oidc discovery endpoints
if url == WELL_KNOWN:
# Minimal discovery document, as defined in OpenID.Discovery
@@ -116,6 +120,8 @@ async def get_json(url):
elif url == JWKS_URI:
return {"keys": []}
+ return {}
+
def _key_file_path() -> str:
"""path to a file containing the private half of a test key"""
@@ -147,12 +153,12 @@ class OidcHandlerTestCase(HomeserverTestCase):
if not HAS_OIDC:
skip = "requires OIDC"
- def default_config(self):
+ def default_config(self) -> Dict[str, Any]:
config = super().default_config()
config["public_baseurl"] = BASE_URL
return config
- def make_homeserver(self, reactor, clock):
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.http_client = Mock(spec=["get_json"])
self.http_client.get_json.side_effect = get_json
self.http_client.user_agent = b"Synapse Test"
@@ -164,7 +170,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
sso_handler = hs.get_sso_handler()
# Mock the render error method.
self.render_error = Mock(return_value=None)
- sso_handler.render_error = self.render_error
+ sso_handler.render_error = self.render_error # type: ignore[assignment]
# Reduce the number of attempts when generating MXIDs.
sso_handler._MAP_USERNAME_RETRIES = 3
@@ -193,14 +199,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
return args
@override_config({"oidc_config": DEFAULT_CONFIG})
- def test_config(self):
+ def test_config(self) -> None:
"""Basic config correctly sets up the callback URL and client auth correctly."""
self.assertEqual(self.provider._callback_url, CALLBACK_URL)
self.assertEqual(self.provider._client_auth.client_id, CLIENT_ID)
self.assertEqual(self.provider._client_auth.client_secret, CLIENT_SECRET)
@override_config({"oidc_config": {**DEFAULT_CONFIG, "discover": True}})
- def test_discovery(self):
+ def test_discovery(self) -> None:
"""The handler should discover the endpoints from OIDC discovery document."""
# This would throw if some metadata were invalid
metadata = self.get_success(self.provider.load_metadata())
@@ -219,13 +225,13 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.http_client.get_json.assert_not_called()
@override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG})
- def test_no_discovery(self):
+ def test_no_discovery(self) -> None:
"""When discovery is disabled, it should not try to load from discovery document."""
self.get_success(self.provider.load_metadata())
self.http_client.get_json.assert_not_called()
@override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG})
- def test_load_jwks(self):
+ def test_load_jwks(self) -> None:
"""JWKS loading is done once (then cached) if used."""
jwks = self.get_success(self.provider.load_jwks())
self.http_client.get_json.assert_called_once_with(JWKS_URI)
@@ -253,7 +259,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.get_failure(self.provider.load_jwks(force=True), RuntimeError)
@override_config({"oidc_config": DEFAULT_CONFIG})
- def test_validate_config(self):
+ def test_validate_config(self) -> None:
"""Provider metadatas are extensively validated."""
h = self.provider
@@ -336,14 +342,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
force_load_metadata()
@override_config({"oidc_config": {**DEFAULT_CONFIG, "skip_verification": True}})
- def test_skip_verification(self):
+ def test_skip_verification(self) -> None:
"""Provider metadata validation can be disabled by config."""
with self.metadata_edit({"issuer": "http://insecure"}):
# This should not throw
get_awaitable_result(self.provider.load_metadata())
@override_config({"oidc_config": DEFAULT_CONFIG})
- def test_redirect_request(self):
+ def test_redirect_request(self) -> None:
"""The redirect request has the right arguments & generates a valid session cookie."""
req = Mock(spec=["cookies"])
req.cookies = []
@@ -387,7 +393,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.assertEqual(redirect, "http://client/redirect")
@override_config({"oidc_config": DEFAULT_CONFIG})
- def test_callback_error(self):
+ def test_callback_error(self) -> None:
"""Errors from the provider returned in the callback are displayed."""
request = Mock(args={})
request.args[b"error"] = [b"invalid_client"]
@@ -399,7 +405,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.assertRenderedError("invalid_client", "some description")
@override_config({"oidc_config": DEFAULT_CONFIG})
- def test_callback(self):
+ def test_callback(self) -> None:
"""Code callback works and display errors if something went wrong.
A lot of scenarios are tested here:
@@ -428,9 +434,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
"username": username,
}
expected_user_id = "@%s:%s" % (username, self.hs.hostname)
- self.provider._exchange_code = simple_async_mock(return_value=token)
- self.provider._parse_id_token = simple_async_mock(return_value=userinfo)
- self.provider._fetch_userinfo = simple_async_mock(return_value=userinfo)
+ self.provider._exchange_code = simple_async_mock(return_value=token) # type: ignore[assignment]
+ self.provider._parse_id_token = simple_async_mock(return_value=userinfo) # type: ignore[assignment]
+ self.provider._fetch_userinfo = simple_async_mock(return_value=userinfo) # type: ignore[assignment]
auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock()
@@ -468,7 +474,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.assertRenderedError("mapping_error")
# Handle ID token errors
- self.provider._parse_id_token = simple_async_mock(raises=Exception())
+ self.provider._parse_id_token = simple_async_mock(raises=Exception()) # type: ignore[assignment]
self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("invalid_token")
@@ -483,7 +489,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
"type": "bearer",
"access_token": "access_token",
}
- self.provider._exchange_code = simple_async_mock(return_value=token)
+ self.provider._exchange_code = simple_async_mock(return_value=token) # type: ignore[assignment]
self.get_success(self.handler.handle_oidc_callback(request))
auth_handler.complete_sso_login.assert_called_once_with(
@@ -510,8 +516,8 @@ class OidcHandlerTestCase(HomeserverTestCase):
id_token = {
"sid": "abcdefgh",
}
- self.provider._parse_id_token = simple_async_mock(return_value=id_token)
- self.provider._exchange_code = simple_async_mock(return_value=token)
+ self.provider._parse_id_token = simple_async_mock(return_value=id_token) # type: ignore[assignment]
+ self.provider._exchange_code = simple_async_mock(return_value=token) # type: ignore[assignment]
auth_handler.complete_sso_login.reset_mock()
self.provider._fetch_userinfo.reset_mock()
self.get_success(self.handler.handle_oidc_callback(request))
@@ -531,21 +537,21 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.render_error.assert_not_called()
# Handle userinfo fetching error
- self.provider._fetch_userinfo = simple_async_mock(raises=Exception())
+ self.provider._fetch_userinfo = simple_async_mock(raises=Exception()) # type: ignore[assignment]
self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("fetch_error")
# Handle code exchange failure
from synapse.handlers.oidc import OidcError
- self.provider._exchange_code = simple_async_mock(
+ self.provider._exchange_code = simple_async_mock( # type: ignore[assignment]
raises=OidcError("invalid_request")
)
self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("invalid_request")
@override_config({"oidc_config": DEFAULT_CONFIG})
- def test_callback_session(self):
+ def test_callback_session(self) -> None:
"""The callback verifies the session presence and validity"""
request = Mock(spec=["args", "getCookie", "cookies"])
@@ -590,7 +596,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
@override_config(
{"oidc_config": {**DEFAULT_CONFIG, "client_auth_method": "client_secret_post"}}
)
- def test_exchange_code(self):
+ def test_exchange_code(self) -> None:
"""Code exchange behaves correctly and handles various error scenarios."""
token = {"type": "bearer"}
token_json = json.dumps(token).encode("utf-8")
@@ -686,7 +692,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
}
}
)
- def test_exchange_code_jwt_key(self):
+ def test_exchange_code_jwt_key(self) -> None:
"""Test that code exchange works with a JWK client secret."""
from authlib.jose import jwt
@@ -741,7 +747,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
}
}
)
- def test_exchange_code_no_auth(self):
+ def test_exchange_code_no_auth(self) -> None:
"""Test that code exchange works with no client secret."""
token = {"type": "bearer"}
self.http_client.request = simple_async_mock(
@@ -776,7 +782,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
}
}
)
- def test_extra_attributes(self):
+ def test_extra_attributes(self) -> None:
"""
Login while using a mapping provider that implements get_extra_attributes.
"""
@@ -790,8 +796,8 @@ class OidcHandlerTestCase(HomeserverTestCase):
"username": "foo",
"phone": "1234567",
}
- self.provider._exchange_code = simple_async_mock(return_value=token)
- self.provider._parse_id_token = simple_async_mock(return_value=userinfo)
+ self.provider._exchange_code = simple_async_mock(return_value=token) # type: ignore[assignment]
+ self.provider._parse_id_token = simple_async_mock(return_value=userinfo) # type: ignore[assignment]
auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock()
@@ -817,12 +823,12 @@ class OidcHandlerTestCase(HomeserverTestCase):
)
@override_config({"oidc_config": DEFAULT_CONFIG})
- def test_map_userinfo_to_user(self):
+ def test_map_userinfo_to_user(self) -> None:
"""Ensure that mapping the userinfo returned from a provider to an MXID works properly."""
auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock()
- userinfo = {
+ userinfo: dict = {
"sub": "test_user",
"username": "test_user",
}
@@ -870,7 +876,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
)
@override_config({"oidc_config": {**DEFAULT_CONFIG, "allow_existing_users": True}})
- def test_map_userinfo_to_existing_user(self):
+ def test_map_userinfo_to_existing_user(self) -> None:
"""Existing users can log in with OpenID Connect when allow_existing_users is True."""
store = self.hs.get_datastores().main
user = UserID.from_string("@test_user:test")
@@ -974,7 +980,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
)
@override_config({"oidc_config": DEFAULT_CONFIG})
- def test_map_userinfo_to_invalid_localpart(self):
+ def test_map_userinfo_to_invalid_localpart(self) -> None:
"""If the mapping provider generates an invalid localpart it should be rejected."""
self.get_success(
_make_callback_with_userinfo(self.hs, {"sub": "test2", "username": "föö"})
@@ -991,7 +997,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
}
}
)
- def test_map_userinfo_to_user_retries(self):
+ def test_map_userinfo_to_user_retries(self) -> None:
"""The mapping provider can retry generating an MXID if the MXID is already in use."""
auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock()
@@ -1039,7 +1045,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
)
@override_config({"oidc_config": DEFAULT_CONFIG})
- def test_empty_localpart(self):
+ def test_empty_localpart(self) -> None:
"""Attempts to map onto an empty localpart should be rejected."""
userinfo = {
"sub": "tester",
@@ -1058,7 +1064,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
}
}
)
- def test_null_localpart(self):
+ def test_null_localpart(self) -> None:
"""Mapping onto a null localpart via an empty OIDC attribute should be rejected"""
userinfo = {
"sub": "tester",
@@ -1075,7 +1081,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
}
}
)
- def test_attribute_requirements(self):
+ def test_attribute_requirements(self) -> None:
"""The required attributes must be met from the OIDC userinfo response."""
auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock()
@@ -1115,7 +1121,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
}
}
)
- def test_attribute_requirements_contains(self):
+ def test_attribute_requirements_contains(self) -> None:
"""Test that auth succeeds if userinfo attribute CONTAINS required value"""
auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock()
@@ -1146,7 +1152,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
}
}
)
- def test_attribute_requirements_mismatch(self):
+ def test_attribute_requirements_mismatch(self) -> None:
"""
Test that auth fails if attributes exist but don't match,
or are non-string values.
@@ -1154,7 +1160,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock()
# userinfo with "test": "not_foobar" attribute should fail
- userinfo = {
+ userinfo: dict = {
"sub": "tester",
"username": "tester",
"test": "not_foobar",
@@ -1248,9 +1254,9 @@ async def _make_callback_with_userinfo(
handler = hs.get_oidc_handler()
provider = handler._providers["oidc"]
- provider._exchange_code = simple_async_mock(return_value={"id_token": ""})
- provider._parse_id_token = simple_async_mock(return_value=userinfo)
- provider._fetch_userinfo = simple_async_mock(return_value=userinfo)
+ provider._exchange_code = simple_async_mock(return_value={"id_token": ""}) # type: ignore[assignment]
+ provider._parse_id_token = simple_async_mock(return_value=userinfo) # type: ignore[assignment]
+ provider._fetch_userinfo = simple_async_mock(return_value=userinfo) # type: ignore[assignment]
state = "state"
session = handler._token_generator.generate_oidc_session_token(
|