diff --git a/tests/test_utils/oidc.py b/tests/test_utils/oidc.py
index 1461d23ee8..d555b24255 100644
--- a/tests/test_utils/oidc.py
+++ b/tests/test_utils/oidc.py
@@ -14,7 +14,7 @@
import json
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, ContextManager, Dict, List, Optional, Tuple
from unittest.mock import Mock, patch
from urllib.parse import parse_qs
@@ -77,14 +77,14 @@ class FakeOidcServer:
self._id_token_overrides: Dict[str, Any] = {}
- def reset_mocks(self):
+ def reset_mocks(self) -> None:
self.request.reset_mock()
self.get_jwks_handler.reset_mock()
self.get_metadata_handler.reset_mock()
self.get_userinfo_handler.reset_mock()
self.post_token_handler.reset_mock()
- def patch_homeserver(self, hs: HomeServer):
+ def patch_homeserver(self, hs: HomeServer) -> ContextManager[Mock]:
"""Patch the ``HomeServer`` HTTP client to handle requests through the ``FakeOidcServer``.
This patch should be used whenever the HS is expected to perform request to the
@@ -188,7 +188,7 @@ class FakeOidcServer:
return self._sign(logout_token)
- def id_token_override(self, overrides: dict):
+ def id_token_override(self, overrides: dict) -> ContextManager[dict]:
"""Temporarily patch the ID token generated by the token endpoint."""
return patch.object(self, "_id_token_overrides", overrides)
@@ -247,7 +247,7 @@ class FakeOidcServer:
metadata: bool = False,
token: bool = False,
userinfo: bool = False,
- ):
+ ) -> ContextManager[Dict[str, Mock]]:
"""A context which makes a set of endpoints return a 500 error.
Args:
|