diff --git a/tests/handlers/oidc_test_key.p8 b/tests/handlers/oidc_test_key.p8
new file mode 100644
index 0000000000..bb92976333
--- /dev/null
+++ b/tests/handlers/oidc_test_key.p8
@@ -0,0 +1,5 @@
+-----BEGIN PRIVATE KEY-----
+MIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQgrHMvFcFjFhei6gHp
+Gfy4C8+6z7634MZbC7SSx4a17GahRANCAATp0YxEzGUXuqszggiFxczDdPgDpCJA
+P18rRuN7FLwZDuzYQPb8zVd8eGh4BqxjiVocICnVWyaSWD96N00I96SW
+-----END PRIVATE KEY-----
diff --git a/tests/handlers/oidc_test_key.pub.pem b/tests/handlers/oidc_test_key.pub.pem
new file mode 100644
index 0000000000..176d4a4b4b
--- /dev/null
+++ b/tests/handlers/oidc_test_key.pub.pem
@@ -0,0 +1,4 @@
+-----BEGIN PUBLIC KEY-----
+MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAE6dGMRMxlF7qrM4IIhcXMw3T4A6Qi
+QD9fK0bjexS8GQ7s2ED2/M1XfHhoeAasY4laHCAp1Vsmklg/ejdNCPeklg==
+-----END PUBLIC KEY-----
diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py
index 0e42013bb9..c9f889b511 100644
--- a/tests/handlers/test_auth.py
+++ b/tests/handlers/test_auth.py
@@ -68,38 +68,45 @@ class AuthTestCase(unittest.HomeserverTestCase):
v.verify(macaroon, self.hs.config.macaroon_secret_key)
def test_short_term_login_token_gives_user_id(self):
- token = self.macaroon_generator.generate_short_term_login_token("a_user", 5000)
- user_id = self.get_success(
- self.auth_handler.validate_short_term_login_token_and_get_user_id(token)
+ token = self.macaroon_generator.generate_short_term_login_token(
+ "a_user", "", 5000
)
- self.assertEqual("a_user", user_id)
+ res = self.get_success(self.auth_handler.validate_short_term_login_token(token))
+ self.assertEqual("a_user", res.user_id)
+ self.assertEqual("", res.auth_provider_id)
# when we advance the clock, the token should be rejected
self.reactor.advance(6)
self.get_failure(
- self.auth_handler.validate_short_term_login_token_and_get_user_id(token),
+ self.auth_handler.validate_short_term_login_token(token),
AuthError,
)
+ def test_short_term_login_token_gives_auth_provider(self):
+ token = self.macaroon_generator.generate_short_term_login_token(
+ "a_user", auth_provider_id="my_idp"
+ )
+ res = self.get_success(self.auth_handler.validate_short_term_login_token(token))
+ self.assertEqual("a_user", res.user_id)
+ self.assertEqual("my_idp", res.auth_provider_id)
+
def test_short_term_login_token_cannot_replace_user_id(self):
- token = self.macaroon_generator.generate_short_term_login_token("a_user", 5000)
+ token = self.macaroon_generator.generate_short_term_login_token(
+ "a_user", "", 5000
+ )
macaroon = pymacaroons.Macaroon.deserialize(token)
- user_id = self.get_success(
- self.auth_handler.validate_short_term_login_token_and_get_user_id(
- macaroon.serialize()
- )
+ res = self.get_success(
+ self.auth_handler.validate_short_term_login_token(macaroon.serialize())
)
- self.assertEqual("a_user", user_id)
+ self.assertEqual("a_user", res.user_id)
# add another "user_id" caveat, which might allow us to override the
# user_id.
macaroon.add_first_party_caveat("user_id = b_user")
self.get_failure(
- self.auth_handler.validate_short_term_login_token_and_get_user_id(
- macaroon.serialize()
- ),
+ self.auth_handler.validate_short_term_login_token(macaroon.serialize()),
AuthError,
)
@@ -113,7 +120,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
)
self.get_success(
- self.auth_handler.validate_short_term_login_token_and_get_user_id(
+ self.auth_handler.validate_short_term_login_token(
self._get_macaroon().serialize()
)
)
@@ -135,7 +142,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
return_value=make_awaitable(self.large_number_of_users)
)
self.get_failure(
- self.auth_handler.validate_short_term_login_token_and_get_user_id(
+ self.auth_handler.validate_short_term_login_token(
self._get_macaroon().serialize()
),
ResourceLimitError,
@@ -159,7 +166,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
ResourceLimitError,
)
self.get_failure(
- self.auth_handler.validate_short_term_login_token_and_get_user_id(
+ self.auth_handler.validate_short_term_login_token(
self._get_macaroon().serialize()
),
ResourceLimitError,
@@ -175,7 +182,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
)
)
self.get_success(
- self.auth_handler.validate_short_term_login_token_and_get_user_id(
+ self.auth_handler.validate_short_term_login_token(
self._get_macaroon().serialize()
)
)
@@ -197,11 +204,13 @@ class AuthTestCase(unittest.HomeserverTestCase):
return_value=make_awaitable(self.small_number_of_users)
)
self.get_success(
- self.auth_handler.validate_short_term_login_token_and_get_user_id(
+ self.auth_handler.validate_short_term_login_token(
self._get_macaroon().serialize()
)
)
def _get_macaroon(self):
- token = self.macaroon_generator.generate_short_term_login_token("user_a", 5000)
+ token = self.macaroon_generator.generate_short_term_login_token(
+ "user_a", "", 5000
+ )
return pymacaroons.Macaroon.deserialize(token)
diff --git a/tests/handlers/test_cas.py b/tests/handlers/test_cas.py
index 6f992291b8..7975af243c 100644
--- a/tests/handlers/test_cas.py
+++ b/tests/handlers/test_cas.py
@@ -66,7 +66,7 @@ class CasHandlerTestCase(HomeserverTestCase):
# check that the auth handler got called as expected
auth_handler.complete_sso_login.assert_called_once_with(
- "@test_user:test", request, "redirect_uri", None, new_user=True
+ "@test_user:test", "cas", request, "redirect_uri", None, new_user=True
)
def test_map_cas_user_to_existing_user(self):
@@ -89,7 +89,7 @@ class CasHandlerTestCase(HomeserverTestCase):
# check that the auth handler got called as expected
auth_handler.complete_sso_login.assert_called_once_with(
- "@test_user:test", request, "redirect_uri", None, new_user=False
+ "@test_user:test", "cas", request, "redirect_uri", None, new_user=False
)
# Subsequent calls should map to the same mxid.
@@ -98,7 +98,7 @@ class CasHandlerTestCase(HomeserverTestCase):
self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
)
auth_handler.complete_sso_login.assert_called_once_with(
- "@test_user:test", request, "redirect_uri", None, new_user=False
+ "@test_user:test", "cas", request, "redirect_uri", None, new_user=False
)
def test_map_cas_user_to_invalid_localpart(self):
@@ -116,7 +116,7 @@ class CasHandlerTestCase(HomeserverTestCase):
# check that the auth handler got called as expected
auth_handler.complete_sso_login.assert_called_once_with(
- "@f=c3=b6=c3=b6:test", request, "redirect_uri", None, new_user=True
+ "@f=c3=b6=c3=b6:test", "cas", request, "redirect_uri", None, new_user=True
)
@override_config(
@@ -160,7 +160,7 @@ class CasHandlerTestCase(HomeserverTestCase):
# check that the auth handler got called as expected
auth_handler.complete_sso_login.assert_called_once_with(
- "@test_user:test", request, "redirect_uri", None, new_user=True
+ "@test_user:test", "cas", request, "redirect_uri", None, new_user=True
)
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index cf1de28fa9..5e9c9c2e88 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
-from typing import Optional
+import os
from urllib.parse import parse_qs, urlparse
from mock import ANY, Mock, patch
@@ -23,6 +23,7 @@ import pymacaroons
from synapse.handlers.sso import MappingException
from synapse.server import HomeServer
from synapse.types import UserID
+from synapse.util.macaroons import get_value_from_macaroon
from tests.test_utils import FakeResponse, get_awaitable_result, simple_async_mock
from tests.unittest import HomeserverTestCase, override_config
@@ -50,7 +51,18 @@ WELL_KNOWN = ISSUER + ".well-known/openid-configuration"
JWKS_URI = ISSUER + ".well-known/jwks.json"
# config for common cases
-COMMON_CONFIG = {
+DEFAULT_CONFIG = {
+ "enabled": True,
+ "client_id": CLIENT_ID,
+ "client_secret": CLIENT_SECRET,
+ "issuer": ISSUER,
+ "scopes": SCOPES,
+ "user_mapping_provider": {"module": __name__ + ".TestMappingProvider"},
+}
+
+# extends the default config with explicit OAuth2 endpoints instead of using discovery
+EXPLICIT_ENDPOINT_CONFIG = {
+ **DEFAULT_CONFIG,
"discover": False,
"authorization_endpoint": AUTHORIZATION_ENDPOINT,
"token_endpoint": TOKEN_ENDPOINT,
@@ -107,6 +119,32 @@ async def get_json(url):
return {"keys": []}
+def _key_file_path() -> str:
+ """path to a file containing the private half of a test key"""
+
+ # this key was generated with:
+ # openssl ecparam -name prime256v1 -genkey -noout |
+ # openssl pkcs8 -topk8 -nocrypt -out oidc_test_key.p8
+ #
+ # we use PKCS8 rather than SEC-1 (which is what openssl ecparam spits out), because
+ # that's what Apple use, and we want to be sure that we work with Apple's keys.
+ #
+ # (For the record: both PKCS8 and SEC-1 specify (different) ways of representing
+ # keys using ASN.1. Both are then typically formatted using PEM, which says: use the
+ # base64-encoded DER encoding of ASN.1, with headers and footers. But we don't
+ # really need to care about any of that.)
+ return os.path.join(os.path.dirname(__file__), "oidc_test_key.p8")
+
+
+def _public_key_file_path() -> str:
+ """path to a file containing the public half of a test key"""
+ # this was generated with:
+ # openssl ec -in oidc_test_key.p8 -pubout -out oidc_test_key.pub.pem
+ #
+ # See above about where oidc_test_key.p8 came from
+ return os.path.join(os.path.dirname(__file__), "oidc_test_key.pub.pem")
+
+
class OidcHandlerTestCase(HomeserverTestCase):
if not HAS_OIDC:
skip = "requires OIDC"
@@ -114,20 +152,6 @@ class OidcHandlerTestCase(HomeserverTestCase):
def default_config(self):
config = super().default_config()
config["public_baseurl"] = BASE_URL
- oidc_config = {
- "enabled": True,
- "client_id": CLIENT_ID,
- "client_secret": CLIENT_SECRET,
- "issuer": ISSUER,
- "scopes": SCOPES,
- "user_mapping_provider": {"module": __name__ + ".TestMappingProvider"},
- }
-
- # Update this config with what's in the default config so that
- # override_config works as expected.
- oidc_config.update(config.get("oidc_config", {}))
- config["oidc_config"] = oidc_config
-
return config
def make_homeserver(self, reactor, clock):
@@ -170,13 +194,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.render_error.reset_mock()
return args
+ @override_config({"oidc_config": DEFAULT_CONFIG})
def test_config(self):
"""Basic config correctly sets up the callback URL and client auth correctly."""
self.assertEqual(self.provider._callback_url, CALLBACK_URL)
self.assertEqual(self.provider._client_auth.client_id, CLIENT_ID)
self.assertEqual(self.provider._client_auth.client_secret, CLIENT_SECRET)
- @override_config({"oidc_config": {"discover": True}})
+ @override_config({"oidc_config": {**DEFAULT_CONFIG, "discover": True}})
def test_discovery(self):
"""The handler should discover the endpoints from OIDC discovery document."""
# This would throw if some metadata were invalid
@@ -195,13 +220,13 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.get_success(self.provider.load_metadata())
self.http_client.get_json.assert_not_called()
- @override_config({"oidc_config": COMMON_CONFIG})
+ @override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG})
def test_no_discovery(self):
"""When discovery is disabled, it should not try to load from discovery document."""
self.get_success(self.provider.load_metadata())
self.http_client.get_json.assert_not_called()
- @override_config({"oidc_config": COMMON_CONFIG})
+ @override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG})
def test_load_jwks(self):
"""JWKS loading is done once (then cached) if used."""
jwks = self.get_success(self.provider.load_jwks())
@@ -236,6 +261,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.http_client.get_json.assert_not_called()
self.assertEqual(jwks, {"keys": []})
+ @override_config({"oidc_config": DEFAULT_CONFIG})
def test_validate_config(self):
"""Provider metadatas are extensively validated."""
h = self.provider
@@ -318,13 +344,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
# Shouldn't raise with a valid userinfo, even without jwks
force_load_metadata()
- @override_config({"oidc_config": {"skip_verification": True}})
+ @override_config({"oidc_config": {**DEFAULT_CONFIG, "skip_verification": True}})
def test_skip_verification(self):
"""Provider metadata validation can be disabled by config."""
with self.metadata_edit({"issuer": "http://insecure"}):
# This should not throw
get_awaitable_result(self.provider.load_metadata())
+ @override_config({"oidc_config": DEFAULT_CONFIG})
def test_redirect_request(self):
"""The redirect request has the right arguments & generates a valid session cookie."""
req = Mock(spec=["cookies"])
@@ -360,20 +387,15 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.assertEqual(name, b"oidc_session")
macaroon = pymacaroons.Macaroon.deserialize(cookie)
- state = self.handler._token_generator._get_value_from_macaroon(
- macaroon, "state"
- )
- nonce = self.handler._token_generator._get_value_from_macaroon(
- macaroon, "nonce"
- )
- redirect = self.handler._token_generator._get_value_from_macaroon(
- macaroon, "client_redirect_url"
- )
+ state = get_value_from_macaroon(macaroon, "state")
+ nonce = get_value_from_macaroon(macaroon, "nonce")
+ redirect = get_value_from_macaroon(macaroon, "client_redirect_url")
self.assertEqual(params["state"], [state])
self.assertEqual(params["nonce"], [nonce])
self.assertEqual(redirect, "http://client/redirect")
+ @override_config({"oidc_config": DEFAULT_CONFIG})
def test_callback_error(self):
"""Errors from the provider returned in the callback are displayed."""
request = Mock(args={})
@@ -385,6 +407,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("invalid_client", "some description")
+ @override_config({"oidc_config": DEFAULT_CONFIG})
def test_callback(self):
"""Code callback works and display errors if something went wrong.
@@ -434,7 +457,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.get_success(self.handler.handle_oidc_callback(request))
auth_handler.complete_sso_login.assert_called_once_with(
- expected_user_id, request, client_redirect_url, None, new_user=True
+ expected_user_id, "oidc", request, client_redirect_url, None, new_user=True
)
self.provider._exchange_code.assert_called_once_with(code)
self.provider._parse_id_token.assert_called_once_with(token, nonce=nonce)
@@ -465,7 +488,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.get_success(self.handler.handle_oidc_callback(request))
auth_handler.complete_sso_login.assert_called_once_with(
- expected_user_id, request, client_redirect_url, None, new_user=False
+ expected_user_id, "oidc", request, client_redirect_url, None, new_user=False
)
self.provider._exchange_code.assert_called_once_with(code)
self.provider._parse_id_token.assert_not_called()
@@ -486,6 +509,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("invalid_request")
+ @override_config({"oidc_config": DEFAULT_CONFIG})
def test_callback_session(self):
"""The callback verifies the session presence and validity"""
request = Mock(spec=["args", "getCookie", "cookies"])
@@ -528,7 +552,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("invalid_request")
- @override_config({"oidc_config": {"client_auth_method": "client_secret_post"}})
+ @override_config(
+ {"oidc_config": {**DEFAULT_CONFIG, "client_auth_method": "client_secret_post"}}
+ )
def test_exchange_code(self):
"""Code exchange behaves correctly and handles various error scenarios."""
token = {"type": "bearer"}
@@ -613,9 +639,105 @@ class OidcHandlerTestCase(HomeserverTestCase):
@override_config(
{
"oidc_config": {
+ "enabled": True,
+ "client_id": CLIENT_ID,
+ "issuer": ISSUER,
+ "client_auth_method": "client_secret_post",
+ "client_secret_jwt_key": {
+ "key_file": _key_file_path(),
+ "jwt_header": {"alg": "ES256", "kid": "ABC789"},
+ "jwt_payload": {"iss": "DEFGHI"},
+ },
+ }
+ }
+ )
+ def test_exchange_code_jwt_key(self):
+ """Test that code exchange works with a JWK client secret."""
+ from authlib.jose import jwt
+
+ token = {"type": "bearer"}
+ self.http_client.request = simple_async_mock(
+ return_value=FakeResponse(
+ code=200, phrase=b"OK", body=json.dumps(token).encode("utf-8")
+ )
+ )
+ code = "code"
+
+ # advance the clock a bit before we start, so we aren't working with zero
+ # timestamps.
+ self.reactor.advance(1000)
+ start_time = self.reactor.seconds()
+ ret = self.get_success(self.provider._exchange_code(code))
+
+ self.assertEqual(ret, token)
+
+ # the request should have hit the token endpoint
+ kwargs = self.http_client.request.call_args[1]
+ self.assertEqual(kwargs["method"], "POST")
+ self.assertEqual(kwargs["uri"], TOKEN_ENDPOINT)
+
+ # the client secret provided to the should be a jwt which can be checked with
+ # the public key
+ args = parse_qs(kwargs["data"].decode("utf-8"))
+ secret = args["client_secret"][0]
+ with open(_public_key_file_path()) as f:
+ key = f.read()
+ claims = jwt.decode(secret, key)
+ self.assertEqual(claims.header["kid"], "ABC789")
+ self.assertEqual(claims["aud"], ISSUER)
+ self.assertEqual(claims["iss"], "DEFGHI")
+ self.assertEqual(claims["sub"], CLIENT_ID)
+ self.assertEqual(claims["iat"], start_time)
+ self.assertGreater(claims["exp"], start_time)
+
+ # check the rest of the POSTed data
+ self.assertEqual(args["grant_type"], ["authorization_code"])
+ self.assertEqual(args["code"], [code])
+ self.assertEqual(args["client_id"], [CLIENT_ID])
+ self.assertEqual(args["redirect_uri"], [CALLBACK_URL])
+
+ @override_config(
+ {
+ "oidc_config": {
+ "enabled": True,
+ "client_id": CLIENT_ID,
+ "issuer": ISSUER,
+ "client_auth_method": "none",
+ }
+ }
+ )
+ def test_exchange_code_no_auth(self):
+ """Test that code exchange works with no client secret."""
+ token = {"type": "bearer"}
+ self.http_client.request = simple_async_mock(
+ return_value=FakeResponse(
+ code=200, phrase=b"OK", body=json.dumps(token).encode("utf-8")
+ )
+ )
+ code = "code"
+ ret = self.get_success(self.provider._exchange_code(code))
+
+ self.assertEqual(ret, token)
+
+ # the request should have hit the token endpoint
+ kwargs = self.http_client.request.call_args[1]
+ self.assertEqual(kwargs["method"], "POST")
+ self.assertEqual(kwargs["uri"], TOKEN_ENDPOINT)
+
+ # check the POSTed data
+ args = parse_qs(kwargs["data"].decode("utf-8"))
+ self.assertEqual(args["grant_type"], ["authorization_code"])
+ self.assertEqual(args["code"], [code])
+ self.assertEqual(args["client_id"], [CLIENT_ID])
+ self.assertEqual(args["redirect_uri"], [CALLBACK_URL])
+
+ @override_config(
+ {
+ "oidc_config": {
+ **DEFAULT_CONFIG,
"user_mapping_provider": {
"module": __name__ + ".TestMappingProviderExtra"
- }
+ },
}
}
)
@@ -651,12 +773,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
auth_handler.complete_sso_login.assert_called_once_with(
"@foo:test",
+ "oidc",
request,
client_redirect_url,
{"phone": "1234567"},
new_user=True,
)
+ @override_config({"oidc_config": DEFAULT_CONFIG})
def test_map_userinfo_to_user(self):
"""Ensure that mapping the userinfo returned from a provider to an MXID works properly."""
auth_handler = self.hs.get_auth_handler()
@@ -668,7 +792,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
}
self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
auth_handler.complete_sso_login.assert_called_once_with(
- "@test_user:test", ANY, ANY, None, new_user=True
+ "@test_user:test", "oidc", ANY, ANY, None, new_user=True
)
auth_handler.complete_sso_login.reset_mock()
@@ -679,7 +803,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
}
self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
auth_handler.complete_sso_login.assert_called_once_with(
- "@test_user_2:test", ANY, ANY, None, new_user=True
+ "@test_user_2:test", "oidc", ANY, ANY, None, new_user=True
)
auth_handler.complete_sso_login.reset_mock()
@@ -697,7 +821,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
"Mapping provider does not support de-duplicating Matrix IDs",
)
- @override_config({"oidc_config": {"allow_existing_users": True}})
+ @override_config({"oidc_config": {**DEFAULT_CONFIG, "allow_existing_users": True}})
def test_map_userinfo_to_existing_user(self):
"""Existing users can log in with OpenID Connect when allow_existing_users is True."""
store = self.hs.get_datastore()
@@ -716,14 +840,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
}
self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
auth_handler.complete_sso_login.assert_called_once_with(
- user.to_string(), ANY, ANY, None, new_user=False
+ user.to_string(), "oidc", ANY, ANY, None, new_user=False
)
auth_handler.complete_sso_login.reset_mock()
# Subsequent calls should map to the same mxid.
self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
auth_handler.complete_sso_login.assert_called_once_with(
- user.to_string(), ANY, ANY, None, new_user=False
+ user.to_string(), "oidc", ANY, ANY, None, new_user=False
)
auth_handler.complete_sso_login.reset_mock()
@@ -738,7 +862,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
}
self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
auth_handler.complete_sso_login.assert_called_once_with(
- user.to_string(), ANY, ANY, None, new_user=False
+ user.to_string(), "oidc", ANY, ANY, None, new_user=False
)
auth_handler.complete_sso_login.reset_mock()
@@ -774,9 +898,10 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
auth_handler.complete_sso_login.assert_called_once_with(
- "@TEST_USER_2:test", ANY, ANY, None, new_user=False
+ "@TEST_USER_2:test", "oidc", ANY, ANY, None, new_user=False
)
+ @override_config({"oidc_config": DEFAULT_CONFIG})
def test_map_userinfo_to_invalid_localpart(self):
"""If the mapping provider generates an invalid localpart it should be rejected."""
self.get_success(
@@ -787,9 +912,10 @@ class OidcHandlerTestCase(HomeserverTestCase):
@override_config(
{
"oidc_config": {
+ **DEFAULT_CONFIG,
"user_mapping_provider": {
"module": __name__ + ".TestMappingProviderFailures"
- }
+ },
}
}
)
@@ -810,7 +936,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
# test_user is already taken, so test_user1 gets registered instead.
auth_handler.complete_sso_login.assert_called_once_with(
- "@test_user1:test", ANY, ANY, None, new_user=True
+ "@test_user1:test", "oidc", ANY, ANY, None, new_user=True
)
auth_handler.complete_sso_login.reset_mock()
@@ -834,6 +960,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
"mapping_error", "Unable to generate a Matrix ID from the SSO response"
)
+ @override_config({"oidc_config": DEFAULT_CONFIG})
def test_empty_localpart(self):
"""Attempts to map onto an empty localpart should be rejected."""
userinfo = {
@@ -846,9 +973,10 @@ class OidcHandlerTestCase(HomeserverTestCase):
@override_config(
{
"oidc_config": {
+ **DEFAULT_CONFIG,
"user_mapping_provider": {
"config": {"localpart_template": "{{ user.username }}"}
- }
+ },
}
}
)
@@ -866,7 +994,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
state: str,
nonce: str,
client_redirect_url: str,
- ui_auth_session_id: Optional[str] = None,
+ ui_auth_session_id: str = "",
) -> str:
from synapse.handlers.oidc_handler import OidcSessionData
@@ -909,6 +1037,7 @@ async def _make_callback_with_userinfo(
idp_id="oidc",
nonce="nonce",
client_redirect_url=client_redirect_url,
+ ui_auth_session_id="",
),
)
request = _build_callback_request("code", state, session)
diff --git a/tests/handlers/test_saml.py b/tests/handlers/test_saml.py
index 029af2853e..30efd43b40 100644
--- a/tests/handlers/test_saml.py
+++ b/tests/handlers/test_saml.py
@@ -131,7 +131,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
# check that the auth handler got called as expected
auth_handler.complete_sso_login.assert_called_once_with(
- "@test_user:test", request, "redirect_uri", None, new_user=True
+ "@test_user:test", "saml", request, "redirect_uri", None, new_user=True
)
@override_config({"saml2_config": {"grandfathered_mxid_source_attribute": "mxid"}})
@@ -157,7 +157,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
# check that the auth handler got called as expected
auth_handler.complete_sso_login.assert_called_once_with(
- "@test_user:test", request, "", None, new_user=False
+ "@test_user:test", "saml", request, "", None, new_user=False
)
# Subsequent calls should map to the same mxid.
@@ -166,7 +166,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
self.handler._handle_authn_response(request, saml_response, "")
)
auth_handler.complete_sso_login.assert_called_once_with(
- "@test_user:test", request, "", None, new_user=False
+ "@test_user:test", "saml", request, "", None, new_user=False
)
def test_map_saml_response_to_invalid_localpart(self):
@@ -214,7 +214,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
# test_user is already taken, so test_user1 gets registered instead.
auth_handler.complete_sso_login.assert_called_once_with(
- "@test_user1:test", request, "", None, new_user=True
+ "@test_user1:test", "saml", request, "", None, new_user=True
)
auth_handler.complete_sso_login.reset_mock()
@@ -310,7 +310,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
# check that the auth handler got called as expected
auth_handler.complete_sso_login.assert_called_once_with(
- "@test_user:test", request, "redirect_uri", None, new_user=True
+ "@test_user:test", "saml", request, "redirect_uri", None, new_user=True
)
|