diff options
Diffstat (limited to 'tests/handlers')
-rw-r--r-- | tests/handlers/test_oidc.py | 12 | ||||
-rw-r--r-- | tests/handlers/test_saml.py | 132 |
2 files changed, 90 insertions, 54 deletions
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py index 9878527bab..464e569ac8 100644 --- a/tests/handlers/test_oidc.py +++ b/tests/handlers/test_oidc.py @@ -23,7 +23,7 @@ from synapse.handlers.oidc_handler import OidcError, OidcMappingProvider from synapse.handlers.sso import MappingException from synapse.types import UserID -from tests.test_utils import FakeResponse +from tests.test_utils import FakeResponse, simple_async_mock from tests.unittest import HomeserverTestCase, override_config # These are a few constants that are used as config parameters in the tests. @@ -82,16 +82,6 @@ class TestMappingProviderFailures(TestMappingProvider): } -def simple_async_mock(return_value=None, raises=None) -> Mock: - # AsyncMock is not available in python3.5, this mimics part of its behaviour - async def cb(*args, **kwargs): - if raises: - raise raises - return return_value - - return Mock(side_effect=cb) - - async def get_json(url): # Mock get_json calls to handle jwks & oidc discovery endpoints if url == WELL_KNOWN: diff --git a/tests/handlers/test_saml.py b/tests/handlers/test_saml.py index d21e5588ca..69927cf6be 100644 --- a/tests/handlers/test_saml.py +++ b/tests/handlers/test_saml.py @@ -12,11 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Optional + +from mock import Mock + import attr from synapse.api.errors import RedirectException -from synapse.handlers.sso import MappingException +from tests.test_utils import simple_async_mock from tests.unittest import HomeserverTestCase, override_config # Check if we have the dependencies to run the tests. @@ -44,6 +48,8 @@ BASE_URL = "https://synapse/" @attr.s class FakeAuthnResponse: ava = attr.ib(type=dict) + assertions = attr.ib(type=list, factory=list) + in_response_to = attr.ib(type=Optional[str], default=None) class TestMappingProvider: @@ -111,15 +117,22 @@ class SamlHandlerTestCase(HomeserverTestCase): def test_map_saml_response_to_user(self): """Ensure that mapping the SAML response returned from a provider to an MXID works properly.""" + + # stub out the auth handler + auth_handler = self.hs.get_auth_handler() + auth_handler.complete_sso_login = simple_async_mock() + + # send a mocked-up SAML response to the callback saml_response = FakeAuthnResponse({"uid": "test_user", "username": "test_user"}) - # The redirect_url doesn't matter with the default user mapping provider. - redirect_url = "" - mxid = self.get_success( - self.handler._map_saml_response_to_user( - saml_response, redirect_url, "user-agent", "10.10.10.10" - ) + request = _mock_request() + self.get_success( + self.handler._handle_authn_response(request, saml_response, "redirect_uri") + ) + + # check that the auth handler got called as expected + auth_handler.complete_sso_login.assert_called_once_with( + "@test_user:test", request, "redirect_uri" ) - self.assertEqual(mxid, "@test_user:test") @override_config({"saml2_config": {"grandfathered_mxid_source_attribute": "mxid"}}) def test_map_saml_response_to_existing_user(self): @@ -129,53 +142,81 @@ class SamlHandlerTestCase(HomeserverTestCase): store.register_user(user_id="@test_user:test", password_hash=None) ) + # stub out the auth handler + auth_handler = self.hs.get_auth_handler() + auth_handler.complete_sso_login = simple_async_mock() + # Map a user via SSO. saml_response = FakeAuthnResponse( {"uid": "tester", "mxid": ["test_user"], "username": "test_user"} ) - redirect_url = "" - mxid = self.get_success( - self.handler._map_saml_response_to_user( - saml_response, redirect_url, "user-agent", "10.10.10.10" - ) + request = _mock_request() + self.get_success( + self.handler._handle_authn_response(request, saml_response, "") + ) + + # check that the auth handler got called as expected + auth_handler.complete_sso_login.assert_called_once_with( + "@test_user:test", request, "" ) - self.assertEqual(mxid, "@test_user:test") # Subsequent calls should map to the same mxid. - mxid = self.get_success( - self.handler._map_saml_response_to_user( - saml_response, redirect_url, "user-agent", "10.10.10.10" - ) + auth_handler.complete_sso_login.reset_mock() + self.get_success( + self.handler._handle_authn_response(request, saml_response, "") + ) + auth_handler.complete_sso_login.assert_called_once_with( + "@test_user:test", request, "" ) - self.assertEqual(mxid, "@test_user:test") def test_map_saml_response_to_invalid_localpart(self): """If the mapping provider generates an invalid localpart it should be rejected.""" + + # stub out the auth handler + auth_handler = self.hs.get_auth_handler() + auth_handler.complete_sso_login = simple_async_mock() + + # mock out the error renderer too + sso_handler = self.hs.get_sso_handler() + sso_handler.render_error = Mock(return_value=None) + saml_response = FakeAuthnResponse({"uid": "test", "username": "föö"}) - redirect_url = "" - e = self.get_failure( - self.handler._map_saml_response_to_user( - saml_response, redirect_url, "user-agent", "10.10.10.10" - ), - MappingException, + request = _mock_request() + self.get_success( + self.handler._handle_authn_response(request, saml_response, ""), + ) + sso_handler.render_error.assert_called_once_with( + request, "mapping_error", "localpart is invalid: föö" ) - self.assertEqual(str(e.value), "localpart is invalid: föö") + auth_handler.complete_sso_login.assert_not_called() def test_map_saml_response_to_user_retries(self): """The mapping provider can retry generating an MXID if the MXID is already in use.""" + + # stub out the auth handler and error renderer + auth_handler = self.hs.get_auth_handler() + auth_handler.complete_sso_login = simple_async_mock() + sso_handler = self.hs.get_sso_handler() + sso_handler.render_error = Mock(return_value=None) + + # register a user to occupy the first-choice MXID store = self.hs.get_datastore() self.get_success( store.register_user(user_id="@test_user:test", password_hash=None) ) + + # send the fake SAML response saml_response = FakeAuthnResponse({"uid": "test", "username": "test_user"}) - redirect_url = "" - mxid = self.get_success( - self.handler._map_saml_response_to_user( - saml_response, redirect_url, "user-agent", "10.10.10.10" - ) + request = _mock_request() + self.get_success( + self.handler._handle_authn_response(request, saml_response, ""), ) + # test_user is already taken, so test_user1 gets registered instead. - self.assertEqual(mxid, "@test_user1:test") + auth_handler.complete_sso_login.assert_called_once_with( + "@test_user1:test", request, "" + ) + auth_handler.complete_sso_login.reset_mock() # Register all of the potential mxids for a particular SAML username. self.get_success( @@ -188,15 +229,15 @@ class SamlHandlerTestCase(HomeserverTestCase): # Now attempt to map to a username, this will fail since all potential usernames are taken. saml_response = FakeAuthnResponse({"uid": "tester", "username": "tester"}) - e = self.get_failure( - self.handler._map_saml_response_to_user( - saml_response, redirect_url, "user-agent", "10.10.10.10" - ), - MappingException, + self.get_success( + self.handler._handle_authn_response(request, saml_response, ""), ) - self.assertEqual( - str(e.value), "Unable to generate a Matrix ID from the SSO response" + sso_handler.render_error.assert_called_once_with( + request, + "mapping_error", + "Unable to generate a Matrix ID from the SSO response", ) + auth_handler.complete_sso_login.assert_not_called() @override_config( { @@ -208,12 +249,17 @@ class SamlHandlerTestCase(HomeserverTestCase): } ) def test_map_saml_response_redirect(self): + """Test a mapping provider that raises a RedirectException""" + saml_response = FakeAuthnResponse({"uid": "test", "username": "test_user"}) - redirect_url = "" + request = _mock_request() e = self.get_failure( - self.handler._map_saml_response_to_user( - saml_response, redirect_url, "user-agent", "10.10.10.10" - ), + self.handler._handle_authn_response(request, saml_response, ""), RedirectException, ) self.assertEqual(e.value.location, b"https://custom-saml-redirect/") + + +def _mock_request(): + """Returns a mock which will stand in as a SynapseRequest""" + return Mock(spec=["getClientIP", "get_user_agent"]) |