diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index 5955410524..49a1842b5c 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
-from typing import Any, Dict, Tuple
+from typing import Any, Awaitable, ContextManager, Dict, Optional, Tuple
from unittest.mock import ANY, Mock, patch
from urllib.parse import parse_qs, urlparse
@@ -23,7 +23,7 @@ from twisted.test.proto_helpers import MemoryReactor
from synapse.handlers.sso import MappingException
from synapse.http.site import SynapseRequest
from synapse.server import HomeServer
-from synapse.types import UserID
+from synapse.types import JsonDict, UserID
from synapse.util import Clock
from synapse.util.macaroons import get_value_from_macaroon
from synapse.util.stringutils import random_string
@@ -34,6 +34,10 @@ from tests.unittest import HomeserverTestCase, override_config
try:
import authlib # noqa: F401
+ from authlib.oidc.core import UserInfo
+ from authlib.oidc.discovery import OpenIDProviderMetadata
+
+ from synapse.handlers.oidc import Token, UserAttributeDict
HAS_OIDC = True
except ImportError:
@@ -70,29 +74,37 @@ EXPLICIT_ENDPOINT_CONFIG = {
class TestMappingProvider:
@staticmethod
- def parse_config(config):
- return
+ def parse_config(config: JsonDict) -> None:
+ return None
- def __init__(self, config):
+ def __init__(self, config: None):
pass
- def get_remote_user_id(self, userinfo):
+ def get_remote_user_id(self, userinfo: "UserInfo") -> str:
return userinfo["sub"]
- async def map_user_attributes(self, userinfo, token):
- return {"localpart": userinfo["username"], "display_name": None}
+ async def map_user_attributes(
+ self, userinfo: "UserInfo", token: "Token"
+ ) -> "UserAttributeDict":
+ # This is testing not providing the full map.
+ return {"localpart": userinfo["username"], "display_name": None} # type: ignore[typeddict-item]
# Do not include get_extra_attributes to test backwards compatibility paths.
class TestMappingProviderExtra(TestMappingProvider):
- async def get_extra_attributes(self, userinfo, token):
+ async def get_extra_attributes(
+ self, userinfo: "UserInfo", token: "Token"
+ ) -> JsonDict:
return {"phone": userinfo["phone"]}
class TestMappingProviderFailures(TestMappingProvider):
- async def map_user_attributes(self, userinfo, token, failures):
- return {
+ # Superclass is testing the legacy interface for map_user_attributes.
+ async def map_user_attributes( # type: ignore[override]
+ self, userinfo: "UserInfo", token: "Token", failures: int
+ ) -> "UserAttributeDict":
+ return { # type: ignore[typeddict-item]
"localpart": userinfo["username"] + (str(failures) if failures else ""),
"display_name": None,
}
@@ -161,13 +173,13 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.hs_patcher.stop()
return super().tearDown()
- def reset_mocks(self):
+ def reset_mocks(self) -> None:
"""Reset all the Mocks."""
self.fake_server.reset_mocks()
self.render_error.reset_mock()
self.complete_sso_login.reset_mock()
- def metadata_edit(self, values):
+ def metadata_edit(self, values: dict) -> ContextManager[Mock]:
"""Modify the result that will be returned by the well-known query"""
metadata = self.fake_server.get_metadata()
@@ -196,7 +208,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
session = self._generate_oidc_session_token(state, nonce, client_redirect_url)
return _build_callback_request(code, state, session), grant
- def assertRenderedError(self, error, error_description=None):
+ def assertRenderedError(
+ self, error: str, error_description: Optional[str] = None
+ ) -> Tuple[Any, ...]:
self.render_error.assert_called_once()
args = self.render_error.call_args[0]
self.assertEqual(args[1], error)
@@ -273,8 +287,8 @@ class OidcHandlerTestCase(HomeserverTestCase):
"""Provider metadatas are extensively validated."""
h = self.provider
- def force_load_metadata():
- async def force_load():
+ def force_load_metadata() -> Awaitable[None]:
+ async def force_load() -> "OpenIDProviderMetadata":
return await h.load_metadata(force=True)
return get_awaitable_result(force_load())
@@ -1198,7 +1212,7 @@ def _build_callback_request(
state: str,
session: str,
ip_address: str = "10.0.0.1",
-):
+) -> Mock:
"""Builds a fake SynapseRequest to mock the browser callback
Returns a Mock object which looks like the SynapseRequest we get from a browser
|