summary refs log tree commit diff
path: root/tests/handlers/test_oidc.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/handlers/test_oidc.py')
-rw-r--r--tests/handlers/test_oidc.py48
1 files changed, 31 insertions, 17 deletions
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index 5955410524..49a1842b5c 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
-from typing import Any, Dict, Tuple
+from typing import Any, Awaitable, ContextManager, Dict, Optional, Tuple
 from unittest.mock import ANY, Mock, patch
 from urllib.parse import parse_qs, urlparse
 
@@ -23,7 +23,7 @@ from twisted.test.proto_helpers import MemoryReactor
 from synapse.handlers.sso import MappingException
 from synapse.http.site import SynapseRequest
 from synapse.server import HomeServer
-from synapse.types import UserID
+from synapse.types import JsonDict, UserID
 from synapse.util import Clock
 from synapse.util.macaroons import get_value_from_macaroon
 from synapse.util.stringutils import random_string
@@ -34,6 +34,10 @@ from tests.unittest import HomeserverTestCase, override_config
 
 try:
     import authlib  # noqa: F401
+    from authlib.oidc.core import UserInfo
+    from authlib.oidc.discovery import OpenIDProviderMetadata
+
+    from synapse.handlers.oidc import Token, UserAttributeDict
 
     HAS_OIDC = True
 except ImportError:
@@ -70,29 +74,37 @@ EXPLICIT_ENDPOINT_CONFIG = {
 
 class TestMappingProvider:
     @staticmethod
-    def parse_config(config):
-        return
+    def parse_config(config: JsonDict) -> None:
+        return None
 
-    def __init__(self, config):
+    def __init__(self, config: None):
         pass
 
-    def get_remote_user_id(self, userinfo):
+    def get_remote_user_id(self, userinfo: "UserInfo") -> str:
         return userinfo["sub"]
 
-    async def map_user_attributes(self, userinfo, token):
-        return {"localpart": userinfo["username"], "display_name": None}
+    async def map_user_attributes(
+        self, userinfo: "UserInfo", token: "Token"
+    ) -> "UserAttributeDict":
+        # This is testing not providing the full map.
+        return {"localpart": userinfo["username"], "display_name": None}  # type: ignore[typeddict-item]
 
     # Do not include get_extra_attributes to test backwards compatibility paths.
 
 
 class TestMappingProviderExtra(TestMappingProvider):
-    async def get_extra_attributes(self, userinfo, token):
+    async def get_extra_attributes(
+        self, userinfo: "UserInfo", token: "Token"
+    ) -> JsonDict:
         return {"phone": userinfo["phone"]}
 
 
 class TestMappingProviderFailures(TestMappingProvider):
-    async def map_user_attributes(self, userinfo, token, failures):
-        return {
+    # Superclass is testing the legacy interface for map_user_attributes.
+    async def map_user_attributes(  # type: ignore[override]
+        self, userinfo: "UserInfo", token: "Token", failures: int
+    ) -> "UserAttributeDict":
+        return {  # type: ignore[typeddict-item]
             "localpart": userinfo["username"] + (str(failures) if failures else ""),
             "display_name": None,
         }
@@ -161,13 +173,13 @@ class OidcHandlerTestCase(HomeserverTestCase):
         self.hs_patcher.stop()
         return super().tearDown()
 
-    def reset_mocks(self):
+    def reset_mocks(self) -> None:
         """Reset all the Mocks."""
         self.fake_server.reset_mocks()
         self.render_error.reset_mock()
         self.complete_sso_login.reset_mock()
 
-    def metadata_edit(self, values):
+    def metadata_edit(self, values: dict) -> ContextManager[Mock]:
         """Modify the result that will be returned by the well-known query"""
 
         metadata = self.fake_server.get_metadata()
@@ -196,7 +208,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
         session = self._generate_oidc_session_token(state, nonce, client_redirect_url)
         return _build_callback_request(code, state, session), grant
 
-    def assertRenderedError(self, error, error_description=None):
+    def assertRenderedError(
+        self, error: str, error_description: Optional[str] = None
+    ) -> Tuple[Any, ...]:
         self.render_error.assert_called_once()
         args = self.render_error.call_args[0]
         self.assertEqual(args[1], error)
@@ -273,8 +287,8 @@ class OidcHandlerTestCase(HomeserverTestCase):
         """Provider metadatas are extensively validated."""
         h = self.provider
 
-        def force_load_metadata():
-            async def force_load():
+        def force_load_metadata() -> Awaitable[None]:
+            async def force_load() -> "OpenIDProviderMetadata":
                 return await h.load_metadata(force=True)
 
             return get_awaitable_result(force_load())
@@ -1198,7 +1212,7 @@ def _build_callback_request(
     state: str,
     session: str,
     ip_address: str = "10.0.0.1",
-):
+) -> Mock:
     """Builds a fake SynapseRequest to mock the browser callback
 
     Returns a Mock object which looks like the SynapseRequest we get from a browser