summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
authorRichard van der Hoff <1389908+richvdh@users.noreply.github.com>2021-03-04 14:44:22 +0000
committerGitHub <noreply@github.com>2021-03-04 14:44:22 +0000
commit7eb6e39a8fe9d42a411cefd905cf2caa29896923 (patch)
treeddcf4fc4eb801299d2e6191c7f34af2d3741c066 /tests
parentFix link in UPGRADES (diff)
downloadsynapse-7eb6e39a8fe9d42a411cefd905cf2caa29896923.tar.xz
Record the SSO Auth Provider in the login token (#9510)
This great big stack of commits is a a whole load of hoop-jumping to make it easier to store additional values in login tokens, and then to actually store the SSO Identity Provider in the login token. (Making use of that data will follow in a subsequent PR.)
Diffstat (limited to 'tests')
-rw-r--r--tests/handlers/test_auth.py49
-rw-r--r--tests/handlers/test_cas.py10
-rw-r--r--tests/handlers/test_oidc.py36
-rw-r--r--tests/handlers/test_saml.py10
4 files changed, 55 insertions, 50 deletions
diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py
index 0e42013bb9..c9f889b511 100644
--- a/tests/handlers/test_auth.py
+++ b/tests/handlers/test_auth.py
@@ -68,38 +68,45 @@ class AuthTestCase(unittest.HomeserverTestCase):
         v.verify(macaroon, self.hs.config.macaroon_secret_key)
 
     def test_short_term_login_token_gives_user_id(self):
-        token = self.macaroon_generator.generate_short_term_login_token("a_user", 5000)
-        user_id = self.get_success(
-            self.auth_handler.validate_short_term_login_token_and_get_user_id(token)
+        token = self.macaroon_generator.generate_short_term_login_token(
+            "a_user", "", 5000
         )
-        self.assertEqual("a_user", user_id)
+        res = self.get_success(self.auth_handler.validate_short_term_login_token(token))
+        self.assertEqual("a_user", res.user_id)
+        self.assertEqual("", res.auth_provider_id)
 
         # when we advance the clock, the token should be rejected
         self.reactor.advance(6)
         self.get_failure(
-            self.auth_handler.validate_short_term_login_token_and_get_user_id(token),
+            self.auth_handler.validate_short_term_login_token(token),
             AuthError,
         )
 
+    def test_short_term_login_token_gives_auth_provider(self):
+        token = self.macaroon_generator.generate_short_term_login_token(
+            "a_user", auth_provider_id="my_idp"
+        )
+        res = self.get_success(self.auth_handler.validate_short_term_login_token(token))
+        self.assertEqual("a_user", res.user_id)
+        self.assertEqual("my_idp", res.auth_provider_id)
+
     def test_short_term_login_token_cannot_replace_user_id(self):
-        token = self.macaroon_generator.generate_short_term_login_token("a_user", 5000)
+        token = self.macaroon_generator.generate_short_term_login_token(
+            "a_user", "", 5000
+        )
         macaroon = pymacaroons.Macaroon.deserialize(token)
 
-        user_id = self.get_success(
-            self.auth_handler.validate_short_term_login_token_and_get_user_id(
-                macaroon.serialize()
-            )
+        res = self.get_success(
+            self.auth_handler.validate_short_term_login_token(macaroon.serialize())
         )
-        self.assertEqual("a_user", user_id)
+        self.assertEqual("a_user", res.user_id)
 
         # add another "user_id" caveat, which might allow us to override the
         # user_id.
         macaroon.add_first_party_caveat("user_id = b_user")
 
         self.get_failure(
-            self.auth_handler.validate_short_term_login_token_and_get_user_id(
-                macaroon.serialize()
-            ),
+            self.auth_handler.validate_short_term_login_token(macaroon.serialize()),
             AuthError,
         )
 
@@ -113,7 +120,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
         )
 
         self.get_success(
-            self.auth_handler.validate_short_term_login_token_and_get_user_id(
+            self.auth_handler.validate_short_term_login_token(
                 self._get_macaroon().serialize()
             )
         )
@@ -135,7 +142,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
             return_value=make_awaitable(self.large_number_of_users)
         )
         self.get_failure(
-            self.auth_handler.validate_short_term_login_token_and_get_user_id(
+            self.auth_handler.validate_short_term_login_token(
                 self._get_macaroon().serialize()
             ),
             ResourceLimitError,
@@ -159,7 +166,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
             ResourceLimitError,
         )
         self.get_failure(
-            self.auth_handler.validate_short_term_login_token_and_get_user_id(
+            self.auth_handler.validate_short_term_login_token(
                 self._get_macaroon().serialize()
             ),
             ResourceLimitError,
@@ -175,7 +182,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
             )
         )
         self.get_success(
-            self.auth_handler.validate_short_term_login_token_and_get_user_id(
+            self.auth_handler.validate_short_term_login_token(
                 self._get_macaroon().serialize()
             )
         )
@@ -197,11 +204,13 @@ class AuthTestCase(unittest.HomeserverTestCase):
             return_value=make_awaitable(self.small_number_of_users)
         )
         self.get_success(
-            self.auth_handler.validate_short_term_login_token_and_get_user_id(
+            self.auth_handler.validate_short_term_login_token(
                 self._get_macaroon().serialize()
             )
         )
 
     def _get_macaroon(self):
-        token = self.macaroon_generator.generate_short_term_login_token("user_a", 5000)
+        token = self.macaroon_generator.generate_short_term_login_token(
+            "user_a", "", 5000
+        )
         return pymacaroons.Macaroon.deserialize(token)
diff --git a/tests/handlers/test_cas.py b/tests/handlers/test_cas.py
index 6f992291b8..7975af243c 100644
--- a/tests/handlers/test_cas.py
+++ b/tests/handlers/test_cas.py
@@ -66,7 +66,7 @@ class CasHandlerTestCase(HomeserverTestCase):
 
         # check that the auth handler got called as expected
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@test_user:test", request, "redirect_uri", None, new_user=True
+            "@test_user:test", "cas", request, "redirect_uri", None, new_user=True
         )
 
     def test_map_cas_user_to_existing_user(self):
@@ -89,7 +89,7 @@ class CasHandlerTestCase(HomeserverTestCase):
 
         # check that the auth handler got called as expected
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@test_user:test", request, "redirect_uri", None, new_user=False
+            "@test_user:test", "cas", request, "redirect_uri", None, new_user=False
         )
 
         # Subsequent calls should map to the same mxid.
@@ -98,7 +98,7 @@ class CasHandlerTestCase(HomeserverTestCase):
             self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
         )
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@test_user:test", request, "redirect_uri", None, new_user=False
+            "@test_user:test", "cas", request, "redirect_uri", None, new_user=False
         )
 
     def test_map_cas_user_to_invalid_localpart(self):
@@ -116,7 +116,7 @@ class CasHandlerTestCase(HomeserverTestCase):
 
         # check that the auth handler got called as expected
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@f=c3=b6=c3=b6:test", request, "redirect_uri", None, new_user=True
+            "@f=c3=b6=c3=b6:test", "cas", request, "redirect_uri", None, new_user=True
         )
 
     @override_config(
@@ -160,7 +160,7 @@ class CasHandlerTestCase(HomeserverTestCase):
 
         # check that the auth handler got called as expected
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@test_user:test", request, "redirect_uri", None, new_user=True
+            "@test_user:test", "cas", request, "redirect_uri", None, new_user=True
         )
 
 
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index cf1de28fa9..02d4b2de0d 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -13,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import json
-from typing import Optional
 from urllib.parse import parse_qs, urlparse
 
 from mock import ANY, Mock, patch
@@ -23,6 +22,7 @@ import pymacaroons
 from synapse.handlers.sso import MappingException
 from synapse.server import HomeServer
 from synapse.types import UserID
+from synapse.util.macaroons import get_value_from_macaroon
 
 from tests.test_utils import FakeResponse, get_awaitable_result, simple_async_mock
 from tests.unittest import HomeserverTestCase, override_config
@@ -360,15 +360,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
         self.assertEqual(name, b"oidc_session")
 
         macaroon = pymacaroons.Macaroon.deserialize(cookie)
-        state = self.handler._token_generator._get_value_from_macaroon(
-            macaroon, "state"
-        )
-        nonce = self.handler._token_generator._get_value_from_macaroon(
-            macaroon, "nonce"
-        )
-        redirect = self.handler._token_generator._get_value_from_macaroon(
-            macaroon, "client_redirect_url"
-        )
+        state = get_value_from_macaroon(macaroon, "state")
+        nonce = get_value_from_macaroon(macaroon, "nonce")
+        redirect = get_value_from_macaroon(macaroon, "client_redirect_url")
 
         self.assertEqual(params["state"], [state])
         self.assertEqual(params["nonce"], [nonce])
@@ -434,7 +428,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
         self.get_success(self.handler.handle_oidc_callback(request))
 
         auth_handler.complete_sso_login.assert_called_once_with(
-            expected_user_id, request, client_redirect_url, None, new_user=True
+            expected_user_id, "oidc", request, client_redirect_url, None, new_user=True
         )
         self.provider._exchange_code.assert_called_once_with(code)
         self.provider._parse_id_token.assert_called_once_with(token, nonce=nonce)
@@ -465,7 +459,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
         self.get_success(self.handler.handle_oidc_callback(request))
 
         auth_handler.complete_sso_login.assert_called_once_with(
-            expected_user_id, request, client_redirect_url, None, new_user=False
+            expected_user_id, "oidc", request, client_redirect_url, None, new_user=False
         )
         self.provider._exchange_code.assert_called_once_with(code)
         self.provider._parse_id_token.assert_not_called()
@@ -651,6 +645,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
 
         auth_handler.complete_sso_login.assert_called_once_with(
             "@foo:test",
+            "oidc",
             request,
             client_redirect_url,
             {"phone": "1234567"},
@@ -668,7 +663,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
         }
         self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@test_user:test", ANY, ANY, None, new_user=True
+            "@test_user:test", "oidc", ANY, ANY, None, new_user=True
         )
         auth_handler.complete_sso_login.reset_mock()
 
@@ -679,7 +674,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
         }
         self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@test_user_2:test", ANY, ANY, None, new_user=True
+            "@test_user_2:test", "oidc", ANY, ANY, None, new_user=True
         )
         auth_handler.complete_sso_login.reset_mock()
 
@@ -716,14 +711,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
         }
         self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
         auth_handler.complete_sso_login.assert_called_once_with(
-            user.to_string(), ANY, ANY, None, new_user=False
+            user.to_string(), "oidc", ANY, ANY, None, new_user=False
         )
         auth_handler.complete_sso_login.reset_mock()
 
         # Subsequent calls should map to the same mxid.
         self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
         auth_handler.complete_sso_login.assert_called_once_with(
-            user.to_string(), ANY, ANY, None, new_user=False
+            user.to_string(), "oidc", ANY, ANY, None, new_user=False
         )
         auth_handler.complete_sso_login.reset_mock()
 
@@ -738,7 +733,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
         }
         self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
         auth_handler.complete_sso_login.assert_called_once_with(
-            user.to_string(), ANY, ANY, None, new_user=False
+            user.to_string(), "oidc", ANY, ANY, None, new_user=False
         )
         auth_handler.complete_sso_login.reset_mock()
 
@@ -774,7 +769,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
 
         self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@TEST_USER_2:test", ANY, ANY, None, new_user=False
+            "@TEST_USER_2:test", "oidc", ANY, ANY, None, new_user=False
         )
 
     def test_map_userinfo_to_invalid_localpart(self):
@@ -810,7 +805,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
 
         # test_user is already taken, so test_user1 gets registered instead.
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@test_user1:test", ANY, ANY, None, new_user=True
+            "@test_user1:test", "oidc", ANY, ANY, None, new_user=True
         )
         auth_handler.complete_sso_login.reset_mock()
 
@@ -866,7 +861,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
         state: str,
         nonce: str,
         client_redirect_url: str,
-        ui_auth_session_id: Optional[str] = None,
+        ui_auth_session_id: str = "",
     ) -> str:
         from synapse.handlers.oidc_handler import OidcSessionData
 
@@ -909,6 +904,7 @@ async def _make_callback_with_userinfo(
             idp_id="oidc",
             nonce="nonce",
             client_redirect_url=client_redirect_url,
+            ui_auth_session_id="",
         ),
     )
     request = _build_callback_request("code", state, session)
diff --git a/tests/handlers/test_saml.py b/tests/handlers/test_saml.py
index 029af2853e..30efd43b40 100644
--- a/tests/handlers/test_saml.py
+++ b/tests/handlers/test_saml.py
@@ -131,7 +131,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
 
         # check that the auth handler got called as expected
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@test_user:test", request, "redirect_uri", None, new_user=True
+            "@test_user:test", "saml", request, "redirect_uri", None, new_user=True
         )
 
     @override_config({"saml2_config": {"grandfathered_mxid_source_attribute": "mxid"}})
@@ -157,7 +157,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
 
         # check that the auth handler got called as expected
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@test_user:test", request, "", None, new_user=False
+            "@test_user:test", "saml", request, "", None, new_user=False
         )
 
         # Subsequent calls should map to the same mxid.
@@ -166,7 +166,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
             self.handler._handle_authn_response(request, saml_response, "")
         )
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@test_user:test", request, "", None, new_user=False
+            "@test_user:test", "saml", request, "", None, new_user=False
         )
 
     def test_map_saml_response_to_invalid_localpart(self):
@@ -214,7 +214,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
 
         # test_user is already taken, so test_user1 gets registered instead.
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@test_user1:test", request, "", None, new_user=True
+            "@test_user1:test", "saml", request, "", None, new_user=True
         )
         auth_handler.complete_sso_login.reset_mock()
 
@@ -310,7 +310,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
 
         # check that the auth handler got called as expected
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@test_user:test", request, "redirect_uri", None, new_user=True
+            "@test_user:test", "saml", request, "redirect_uri", None, new_user=True
         )