diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py
index 0e42013bb9..c9f889b511 100644
--- a/tests/handlers/test_auth.py
+++ b/tests/handlers/test_auth.py
@@ -68,38 +68,45 @@ class AuthTestCase(unittest.HomeserverTestCase):
v.verify(macaroon, self.hs.config.macaroon_secret_key)
def test_short_term_login_token_gives_user_id(self):
- token = self.macaroon_generator.generate_short_term_login_token("a_user", 5000)
- user_id = self.get_success(
- self.auth_handler.validate_short_term_login_token_and_get_user_id(token)
+ token = self.macaroon_generator.generate_short_term_login_token(
+ "a_user", "", 5000
)
- self.assertEqual("a_user", user_id)
+ res = self.get_success(self.auth_handler.validate_short_term_login_token(token))
+ self.assertEqual("a_user", res.user_id)
+ self.assertEqual("", res.auth_provider_id)
# when we advance the clock, the token should be rejected
self.reactor.advance(6)
self.get_failure(
- self.auth_handler.validate_short_term_login_token_and_get_user_id(token),
+ self.auth_handler.validate_short_term_login_token(token),
AuthError,
)
+ def test_short_term_login_token_gives_auth_provider(self):
+ token = self.macaroon_generator.generate_short_term_login_token(
+ "a_user", auth_provider_id="my_idp"
+ )
+ res = self.get_success(self.auth_handler.validate_short_term_login_token(token))
+ self.assertEqual("a_user", res.user_id)
+ self.assertEqual("my_idp", res.auth_provider_id)
+
def test_short_term_login_token_cannot_replace_user_id(self):
- token = self.macaroon_generator.generate_short_term_login_token("a_user", 5000)
+ token = self.macaroon_generator.generate_short_term_login_token(
+ "a_user", "", 5000
+ )
macaroon = pymacaroons.Macaroon.deserialize(token)
- user_id = self.get_success(
- self.auth_handler.validate_short_term_login_token_and_get_user_id(
- macaroon.serialize()
- )
+ res = self.get_success(
+ self.auth_handler.validate_short_term_login_token(macaroon.serialize())
)
- self.assertEqual("a_user", user_id)
+ self.assertEqual("a_user", res.user_id)
# add another "user_id" caveat, which might allow us to override the
# user_id.
macaroon.add_first_party_caveat("user_id = b_user")
self.get_failure(
- self.auth_handler.validate_short_term_login_token_and_get_user_id(
- macaroon.serialize()
- ),
+ self.auth_handler.validate_short_term_login_token(macaroon.serialize()),
AuthError,
)
@@ -113,7 +120,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
)
self.get_success(
- self.auth_handler.validate_short_term_login_token_and_get_user_id(
+ self.auth_handler.validate_short_term_login_token(
self._get_macaroon().serialize()
)
)
@@ -135,7 +142,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
return_value=make_awaitable(self.large_number_of_users)
)
self.get_failure(
- self.auth_handler.validate_short_term_login_token_and_get_user_id(
+ self.auth_handler.validate_short_term_login_token(
self._get_macaroon().serialize()
),
ResourceLimitError,
@@ -159,7 +166,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
ResourceLimitError,
)
self.get_failure(
- self.auth_handler.validate_short_term_login_token_and_get_user_id(
+ self.auth_handler.validate_short_term_login_token(
self._get_macaroon().serialize()
),
ResourceLimitError,
@@ -175,7 +182,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
)
)
self.get_success(
- self.auth_handler.validate_short_term_login_token_and_get_user_id(
+ self.auth_handler.validate_short_term_login_token(
self._get_macaroon().serialize()
)
)
@@ -197,11 +204,13 @@ class AuthTestCase(unittest.HomeserverTestCase):
return_value=make_awaitable(self.small_number_of_users)
)
self.get_success(
- self.auth_handler.validate_short_term_login_token_and_get_user_id(
+ self.auth_handler.validate_short_term_login_token(
self._get_macaroon().serialize()
)
)
def _get_macaroon(self):
- token = self.macaroon_generator.generate_short_term_login_token("user_a", 5000)
+ token = self.macaroon_generator.generate_short_term_login_token(
+ "user_a", "", 5000
+ )
return pymacaroons.Macaroon.deserialize(token)
|