summary refs log tree commit diff
path: root/tests/handlers
diff options
context:
space:
mode:
Diffstat (limited to 'tests/handlers')
-rw-r--r--tests/handlers/test_oidc.py14
1 files changed, 7 insertions, 7 deletions
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index 0d51705849..630e6da808 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -154,6 +154,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
         )
 
         self.handler = OidcHandler(hs)
+        # Mock the render error method.
+        self.render_error = Mock(return_value=None)
+        self.handler._sso_handler.render_error = self.render_error
 
         return hs
 
@@ -161,12 +164,12 @@ class OidcHandlerTestCase(HomeserverTestCase):
         return patch.dict(self.handler._provider_metadata, values)
 
     def assertRenderedError(self, error, error_description=None):
-        args = self.handler._render_error.call_args[0]
+        args = self.render_error.call_args[0]
         self.assertEqual(args[1], error)
         if error_description is not None:
             self.assertEqual(args[2], error_description)
         # Reset the render_error mock
-        self.handler._render_error.reset_mock()
+        self.render_error.reset_mock()
 
     def test_config(self):
         """Basic config correctly sets up the callback URL and client auth correctly."""
@@ -356,7 +359,6 @@ class OidcHandlerTestCase(HomeserverTestCase):
 
     def test_callback_error(self):
         """Errors from the provider returned in the callback are displayed."""
-        self.handler._render_error = Mock()
         request = Mock(args={})
         request.args[b"error"] = [b"invalid_client"]
         self.get_success(self.handler.handle_oidc_callback(request))
@@ -387,7 +389,6 @@ class OidcHandlerTestCase(HomeserverTestCase):
             "preferred_username": "bar",
         }
         user_id = "@foo:domain.org"
-        self.handler._render_error = Mock(return_value=None)
         self.handler._exchange_code = simple_async_mock(return_value=token)
         self.handler._parse_id_token = simple_async_mock(return_value=userinfo)
         self.handler._fetch_userinfo = simple_async_mock(return_value=userinfo)
@@ -435,7 +436,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
             userinfo, token, user_agent, ip_address
         )
         self.handler._fetch_userinfo.assert_not_called()
-        self.handler._render_error.assert_not_called()
+        self.render_error.assert_not_called()
 
         # Handle mapping errors
         self.handler._map_userinfo_to_user = simple_async_mock(
@@ -469,7 +470,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
             userinfo, token, user_agent, ip_address
         )
         self.handler._fetch_userinfo.assert_called_once_with(token)
-        self.handler._render_error.assert_not_called()
+        self.render_error.assert_not_called()
 
         # Handle userinfo fetching error
         self.handler._fetch_userinfo = simple_async_mock(raises=Exception())
@@ -485,7 +486,6 @@ class OidcHandlerTestCase(HomeserverTestCase):
 
     def test_callback_session(self):
         """The callback verifies the session presence and validity"""
-        self.handler._render_error = Mock(return_value=None)
         request = Mock(spec=["args", "getCookie", "addCookie"])
 
         # Missing cookie