summary refs log tree commit diff
path: root/tests/test_utils/oidc.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_utils/oidc.py')
-rw-r--r--tests/test_utils/oidc.py10
1 files changed, 5 insertions, 5 deletions
diff --git a/tests/test_utils/oidc.py b/tests/test_utils/oidc.py
index 1461d23ee8..d555b24255 100644
--- a/tests/test_utils/oidc.py
+++ b/tests/test_utils/oidc.py
@@ -14,7 +14,7 @@
 
 
 import json
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, ContextManager, Dict, List, Optional, Tuple
 from unittest.mock import Mock, patch
 from urllib.parse import parse_qs
 
@@ -77,14 +77,14 @@ class FakeOidcServer:
 
         self._id_token_overrides: Dict[str, Any] = {}
 
-    def reset_mocks(self):
+    def reset_mocks(self) -> None:
         self.request.reset_mock()
         self.get_jwks_handler.reset_mock()
         self.get_metadata_handler.reset_mock()
         self.get_userinfo_handler.reset_mock()
         self.post_token_handler.reset_mock()
 
-    def patch_homeserver(self, hs: HomeServer):
+    def patch_homeserver(self, hs: HomeServer) -> ContextManager[Mock]:
         """Patch the ``HomeServer`` HTTP client to handle requests through the ``FakeOidcServer``.
 
         This patch should be used whenever the HS is expected to perform request to the
@@ -188,7 +188,7 @@ class FakeOidcServer:
 
         return self._sign(logout_token)
 
-    def id_token_override(self, overrides: dict):
+    def id_token_override(self, overrides: dict) -> ContextManager[dict]:
         """Temporarily patch the ID token generated by the token endpoint."""
         return patch.object(self, "_id_token_overrides", overrides)
 
@@ -247,7 +247,7 @@ class FakeOidcServer:
         metadata: bool = False,
         token: bool = False,
         userinfo: bool = False,
-    ):
+    ) -> ContextManager[Dict[str, Mock]]:
         """A context which makes a set of endpoints return a 500 error.
 
         Args: