summary refs log tree commit diff
path: root/tests/handlers
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--tests/handlers/test_oidc.py24
1 files changed, 18 insertions, 6 deletions
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index b6f436c016..0d51705849 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -394,7 +394,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
         self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id)
         self.handler._auth_handler.complete_sso_login = simple_async_mock()
         request = Mock(
-            spec=["args", "getCookie", "addCookie", "requestHeaders", "getClientIP"]
+            spec=[
+                "args",
+                "getCookie",
+                "addCookie",
+                "requestHeaders",
+                "getClientIP",
+                "get_user_agent",
+            ]
         )
 
         code = "code"
@@ -414,9 +421,8 @@ class OidcHandlerTestCase(HomeserverTestCase):
         request.args[b"code"] = [code.encode("utf-8")]
         request.args[b"state"] = [state.encode("utf-8")]
 
-        request.requestHeaders = Mock(spec=["getRawHeaders"])
-        request.requestHeaders.getRawHeaders.return_value = [user_agent.encode("ascii")]
         request.getClientIP.return_value = ip_address
+        request.get_user_agent.return_value = user_agent
 
         self.get_success(self.handler.handle_oidc_callback(request))
 
@@ -621,7 +627,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
         self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id)
         self.handler._auth_handler.complete_sso_login = simple_async_mock()
         request = Mock(
-            spec=["args", "getCookie", "addCookie", "requestHeaders", "getClientIP"]
+            spec=[
+                "args",
+                "getCookie",
+                "addCookie",
+                "requestHeaders",
+                "getClientIP",
+                "get_user_agent",
+            ]
         )
 
         state = "state"
@@ -637,9 +650,8 @@ class OidcHandlerTestCase(HomeserverTestCase):
         request.args[b"code"] = [b"code"]
         request.args[b"state"] = [state.encode("utf-8")]
 
-        request.requestHeaders = Mock(spec=["getRawHeaders"])
-        request.requestHeaders.getRawHeaders.return_value = [b"Browser"]
         request.getClientIP.return_value = "10.0.0.1"
+        request.get_user_agent.return_value = "Browser"
 
         self.get_success(self.handler.handle_oidc_callback(request))