summary refs log tree commit diff
path: root/tests/handlers/test_oidc.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/handlers/test_oidc.py')
-rw-r--r--tests/handlers/test_oidc.py36
1 files changed, 16 insertions, 20 deletions
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index cf1de28fa9..02d4b2de0d 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -13,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import json
-from typing import Optional
 from urllib.parse import parse_qs, urlparse
 
 from mock import ANY, Mock, patch
@@ -23,6 +22,7 @@ import pymacaroons
 from synapse.handlers.sso import MappingException
 from synapse.server import HomeServer
 from synapse.types import UserID
+from synapse.util.macaroons import get_value_from_macaroon
 
 from tests.test_utils import FakeResponse, get_awaitable_result, simple_async_mock
 from tests.unittest import HomeserverTestCase, override_config
@@ -360,15 +360,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
         self.assertEqual(name, b"oidc_session")
 
         macaroon = pymacaroons.Macaroon.deserialize(cookie)
-        state = self.handler._token_generator._get_value_from_macaroon(
-            macaroon, "state"
-        )
-        nonce = self.handler._token_generator._get_value_from_macaroon(
-            macaroon, "nonce"
-        )
-        redirect = self.handler._token_generator._get_value_from_macaroon(
-            macaroon, "client_redirect_url"
-        )
+        state = get_value_from_macaroon(macaroon, "state")
+        nonce = get_value_from_macaroon(macaroon, "nonce")
+        redirect = get_value_from_macaroon(macaroon, "client_redirect_url")
 
         self.assertEqual(params["state"], [state])
         self.assertEqual(params["nonce"], [nonce])
@@ -434,7 +428,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
         self.get_success(self.handler.handle_oidc_callback(request))
 
         auth_handler.complete_sso_login.assert_called_once_with(
-            expected_user_id, request, client_redirect_url, None, new_user=True
+            expected_user_id, "oidc", request, client_redirect_url, None, new_user=True
         )
         self.provider._exchange_code.assert_called_once_with(code)
         self.provider._parse_id_token.assert_called_once_with(token, nonce=nonce)
@@ -465,7 +459,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
         self.get_success(self.handler.handle_oidc_callback(request))
 
         auth_handler.complete_sso_login.assert_called_once_with(
-            expected_user_id, request, client_redirect_url, None, new_user=False
+            expected_user_id, "oidc", request, client_redirect_url, None, new_user=False
         )
         self.provider._exchange_code.assert_called_once_with(code)
         self.provider._parse_id_token.assert_not_called()
@@ -651,6 +645,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
 
         auth_handler.complete_sso_login.assert_called_once_with(
             "@foo:test",
+            "oidc",
             request,
             client_redirect_url,
             {"phone": "1234567"},
@@ -668,7 +663,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
         }
         self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@test_user:test", ANY, ANY, None, new_user=True
+            "@test_user:test", "oidc", ANY, ANY, None, new_user=True
         )
         auth_handler.complete_sso_login.reset_mock()
 
@@ -679,7 +674,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
         }
         self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@test_user_2:test", ANY, ANY, None, new_user=True
+            "@test_user_2:test", "oidc", ANY, ANY, None, new_user=True
         )
         auth_handler.complete_sso_login.reset_mock()
 
@@ -716,14 +711,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
         }
         self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
         auth_handler.complete_sso_login.assert_called_once_with(
-            user.to_string(), ANY, ANY, None, new_user=False
+            user.to_string(), "oidc", ANY, ANY, None, new_user=False
         )
         auth_handler.complete_sso_login.reset_mock()
 
         # Subsequent calls should map to the same mxid.
         self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
         auth_handler.complete_sso_login.assert_called_once_with(
-            user.to_string(), ANY, ANY, None, new_user=False
+            user.to_string(), "oidc", ANY, ANY, None, new_user=False
         )
         auth_handler.complete_sso_login.reset_mock()
 
@@ -738,7 +733,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
         }
         self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
         auth_handler.complete_sso_login.assert_called_once_with(
-            user.to_string(), ANY, ANY, None, new_user=False
+            user.to_string(), "oidc", ANY, ANY, None, new_user=False
         )
         auth_handler.complete_sso_login.reset_mock()
 
@@ -774,7 +769,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
 
         self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@TEST_USER_2:test", ANY, ANY, None, new_user=False
+            "@TEST_USER_2:test", "oidc", ANY, ANY, None, new_user=False
         )
 
     def test_map_userinfo_to_invalid_localpart(self):
@@ -810,7 +805,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
 
         # test_user is already taken, so test_user1 gets registered instead.
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@test_user1:test", ANY, ANY, None, new_user=True
+            "@test_user1:test", "oidc", ANY, ANY, None, new_user=True
         )
         auth_handler.complete_sso_login.reset_mock()
 
@@ -866,7 +861,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
         state: str,
         nonce: str,
         client_redirect_url: str,
-        ui_auth_session_id: Optional[str] = None,
+        ui_auth_session_id: str = "",
     ) -> str:
         from synapse.handlers.oidc_handler import OidcSessionData
 
@@ -909,6 +904,7 @@ async def _make_callback_with_userinfo(
             idp_id="oidc",
             nonce="nonce",
             client_redirect_url=client_redirect_url,
+            ui_auth_session_id="",
         ),
     )
     request = _build_callback_request("code", state, session)