diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
index d0452e1490..0b24b89a2e 100644
--- a/tests/handlers/test_federation.py
+++ b/tests/handlers/test_federation.py
@@ -126,7 +126,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
room_version,
)
- with LoggingContext(request="send_rejected"):
+ with LoggingContext("send_rejected"):
d = run_in_background(self.handler.on_receive_pdu, OTHER_SERVER, ev)
self.get_success(d)
@@ -178,7 +178,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
room_version,
)
- with LoggingContext(request="send_rejected"):
+ with LoggingContext("send_rejected"):
d = run_in_background(self.handler.on_receive_pdu, OTHER_SERVER, ev)
self.get_success(d)
@@ -198,7 +198,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
# the auth code requires that a signature exists, but doesn't check that
# signature... go figure.
join_event.signatures[other_server] = {"x": "y"}
- with LoggingContext(request="send_join"):
+ with LoggingContext("send_join"):
d = run_in_background(
self.handler.on_send_join_request, other_server, join_event
)
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index 1d99a45436..464e569ac8 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -15,7 +15,7 @@
import json
from urllib.parse import parse_qs, urlparse
-from mock import Mock, patch
+from mock import ANY, Mock, patch
import pymacaroons
@@ -23,7 +23,7 @@ from synapse.handlers.oidc_handler import OidcError, OidcMappingProvider
from synapse.handlers.sso import MappingException
from synapse.types import UserID
-from tests.test_utils import FakeResponse
+from tests.test_utils import FakeResponse, simple_async_mock
from tests.unittest import HomeserverTestCase, override_config
# These are a few constants that are used as config parameters in the tests.
@@ -82,16 +82,6 @@ class TestMappingProviderFailures(TestMappingProvider):
}
-def simple_async_mock(return_value=None, raises=None):
- # AsyncMock is not available in python3.5, this mimics part of its behaviour
- async def cb(*args, **kwargs):
- if raises:
- raise raises
- return return_value
-
- return Mock(side_effect=cb)
-
-
async def get_json(url):
# Mock get_json calls to handle jwks & oidc discovery endpoints
if url == WELL_KNOWN:
@@ -160,6 +150,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.assertEqual(args[2], error_description)
# Reset the render_error mock
self.render_error.reset_mock()
+ return args
def test_config(self):
"""Basic config correctly sets up the callback URL and client auth correctly."""
@@ -374,26 +365,17 @@ class OidcHandlerTestCase(HomeserverTestCase):
"id_token": "id_token",
"access_token": "access_token",
}
+ username = "bar"
userinfo = {
"sub": "foo",
- "preferred_username": "bar",
+ "username": username,
}
- user_id = "@foo:domain.org"
+ expected_user_id = "@%s:%s" % (username, self.hs.hostname)
self.handler._exchange_code = simple_async_mock(return_value=token)
self.handler._parse_id_token = simple_async_mock(return_value=userinfo)
self.handler._fetch_userinfo = simple_async_mock(return_value=userinfo)
- self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id)
- self.handler._auth_handler.complete_sso_login = simple_async_mock()
- request = Mock(
- spec=[
- "args",
- "getCookie",
- "addCookie",
- "requestHeaders",
- "getClientIP",
- "get_user_agent",
- ]
- )
+ auth_handler = self.hs.get_auth_handler()
+ auth_handler.complete_sso_login = simple_async_mock()
code = "code"
state = "state"
@@ -401,64 +383,54 @@ class OidcHandlerTestCase(HomeserverTestCase):
client_redirect_url = "http://client/redirect"
user_agent = "Browser"
ip_address = "10.0.0.1"
- request.getCookie.return_value = self.handler._generate_oidc_session_token(
+ session = self.handler._generate_oidc_session_token(
state=state,
nonce=nonce,
client_redirect_url=client_redirect_url,
ui_auth_session_id=None,
)
-
- request.args = {}
- request.args[b"code"] = [code.encode("utf-8")]
- request.args[b"state"] = [state.encode("utf-8")]
-
- request.getClientIP.return_value = ip_address
- request.get_user_agent.return_value = user_agent
+ request = self._build_callback_request(
+ code, state, session, user_agent=user_agent, ip_address=ip_address
+ )
self.get_success(self.handler.handle_oidc_callback(request))
- self.handler._auth_handler.complete_sso_login.assert_called_once_with(
- user_id, request, client_redirect_url, {},
+ auth_handler.complete_sso_login.assert_called_once_with(
+ expected_user_id, request, client_redirect_url, {},
)
self.handler._exchange_code.assert_called_once_with(code)
self.handler._parse_id_token.assert_called_once_with(token, nonce=nonce)
- self.handler._map_userinfo_to_user.assert_called_once_with(
- userinfo, token, user_agent, ip_address
- )
self.handler._fetch_userinfo.assert_not_called()
self.render_error.assert_not_called()
# Handle mapping errors
- self.handler._map_userinfo_to_user = simple_async_mock(
- raises=MappingException()
- )
- self.get_success(self.handler.handle_oidc_callback(request))
- self.assertRenderedError("mapping_error")
- self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id)
+ with patch.object(
+ self.handler,
+ "_remote_id_from_userinfo",
+ new=Mock(side_effect=MappingException()),
+ ):
+ self.get_success(self.handler.handle_oidc_callback(request))
+ self.assertRenderedError("mapping_error")
# Handle ID token errors
self.handler._parse_id_token = simple_async_mock(raises=Exception())
self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("invalid_token")
- self.handler._auth_handler.complete_sso_login.reset_mock()
+ auth_handler.complete_sso_login.reset_mock()
self.handler._exchange_code.reset_mock()
self.handler._parse_id_token.reset_mock()
- self.handler._map_userinfo_to_user.reset_mock()
self.handler._fetch_userinfo.reset_mock()
# With userinfo fetching
self.handler._scopes = [] # do not ask the "openid" scope
self.get_success(self.handler.handle_oidc_callback(request))
- self.handler._auth_handler.complete_sso_login.assert_called_once_with(
- user_id, request, client_redirect_url, {},
+ auth_handler.complete_sso_login.assert_called_once_with(
+ expected_user_id, request, client_redirect_url, {},
)
self.handler._exchange_code.assert_called_once_with(code)
self.handler._parse_id_token.assert_not_called()
- self.handler._map_userinfo_to_user.assert_called_once_with(
- userinfo, token, user_agent, ip_address
- )
self.handler._fetch_userinfo.assert_called_once_with(token)
self.render_error.assert_not_called()
@@ -609,72 +581,55 @@ class OidcHandlerTestCase(HomeserverTestCase):
}
userinfo = {
"sub": "foo",
+ "username": "foo",
"phone": "1234567",
}
- user_id = "@foo:domain.org"
self.handler._exchange_code = simple_async_mock(return_value=token)
self.handler._parse_id_token = simple_async_mock(return_value=userinfo)
- self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id)
- self.handler._auth_handler.complete_sso_login = simple_async_mock()
- request = Mock(
- spec=[
- "args",
- "getCookie",
- "addCookie",
- "requestHeaders",
- "getClientIP",
- "get_user_agent",
- ]
- )
+ auth_handler = self.hs.get_auth_handler()
+ auth_handler.complete_sso_login = simple_async_mock()
state = "state"
client_redirect_url = "http://client/redirect"
- request.getCookie.return_value = self.handler._generate_oidc_session_token(
+ session = self.handler._generate_oidc_session_token(
state=state,
nonce="nonce",
client_redirect_url=client_redirect_url,
ui_auth_session_id=None,
)
-
- request.args = {}
- request.args[b"code"] = [b"code"]
- request.args[b"state"] = [state.encode("utf-8")]
-
- request.getClientIP.return_value = "10.0.0.1"
- request.get_user_agent.return_value = "Browser"
+ request = self._build_callback_request("code", state, session)
self.get_success(self.handler.handle_oidc_callback(request))
- self.handler._auth_handler.complete_sso_login.assert_called_once_with(
- user_id, request, client_redirect_url, {"phone": "1234567"},
+ auth_handler.complete_sso_login.assert_called_once_with(
+ "@foo:test", request, client_redirect_url, {"phone": "1234567"},
)
def test_map_userinfo_to_user(self):
"""Ensure that mapping the userinfo returned from a provider to an MXID works properly."""
+ auth_handler = self.hs.get_auth_handler()
+ auth_handler.complete_sso_login = simple_async_mock()
+
userinfo = {
"sub": "test_user",
"username": "test_user",
}
- # The token doesn't matter with the default user mapping provider.
- token = {}
- mxid = self.get_success(
- self.handler._map_userinfo_to_user(
- userinfo, token, "user-agent", "10.10.10.10"
- )
+ self._make_callback_with_userinfo(userinfo)
+ auth_handler.complete_sso_login.assert_called_once_with(
+ "@test_user:test", ANY, ANY, {}
)
- self.assertEqual(mxid, "@test_user:test")
+ auth_handler.complete_sso_login.reset_mock()
# Some providers return an integer ID.
userinfo = {
"sub": 1234,
"username": "test_user_2",
}
- mxid = self.get_success(
- self.handler._map_userinfo_to_user(
- userinfo, token, "user-agent", "10.10.10.10"
- )
+ self._make_callback_with_userinfo(userinfo)
+ auth_handler.complete_sso_login.assert_called_once_with(
+ "@test_user_2:test", ANY, ANY, {}
)
- self.assertEqual(mxid, "@test_user_2:test")
+ auth_handler.complete_sso_login.reset_mock()
# Test if the mxid is already taken
store = self.hs.get_datastore()
@@ -683,14 +638,11 @@ class OidcHandlerTestCase(HomeserverTestCase):
store.register_user(user_id=user3.to_string(), password_hash=None)
)
userinfo = {"sub": "test3", "username": "test_user_3"}
- e = self.get_failure(
- self.handler._map_userinfo_to_user(
- userinfo, token, "user-agent", "10.10.10.10"
- ),
- MappingException,
- )
- self.assertEqual(
- str(e.value), "Mapping provider does not support de-duplicating Matrix IDs",
+ self._make_callback_with_userinfo(userinfo)
+ auth_handler.complete_sso_login.assert_not_called()
+ self.assertRenderedError(
+ "mapping_error",
+ "Mapping provider does not support de-duplicating Matrix IDs",
)
@override_config({"oidc_config": {"allow_existing_users": True}})
@@ -702,26 +654,26 @@ class OidcHandlerTestCase(HomeserverTestCase):
store.register_user(user_id=user.to_string(), password_hash=None)
)
+ auth_handler = self.hs.get_auth_handler()
+ auth_handler.complete_sso_login = simple_async_mock()
+
# Map a user via SSO.
userinfo = {
"sub": "test",
"username": "test_user",
}
- token = {}
- mxid = self.get_success(
- self.handler._map_userinfo_to_user(
- userinfo, token, "user-agent", "10.10.10.10"
- )
+ self._make_callback_with_userinfo(userinfo)
+ auth_handler.complete_sso_login.assert_called_once_with(
+ user.to_string(), ANY, ANY, {},
)
- self.assertEqual(mxid, "@test_user:test")
+ auth_handler.complete_sso_login.reset_mock()
# Subsequent calls should map to the same mxid.
- mxid = self.get_success(
- self.handler._map_userinfo_to_user(
- userinfo, token, "user-agent", "10.10.10.10"
- )
+ self._make_callback_with_userinfo(userinfo)
+ auth_handler.complete_sso_login.assert_called_once_with(
+ user.to_string(), ANY, ANY, {},
)
- self.assertEqual(mxid, "@test_user:test")
+ auth_handler.complete_sso_login.reset_mock()
# Note that a second SSO user can be mapped to the same Matrix ID. (This
# requires a unique sub, but something that maps to the same matrix ID,
@@ -732,13 +684,11 @@ class OidcHandlerTestCase(HomeserverTestCase):
"sub": "test1",
"username": "test_user",
}
- token = {}
- mxid = self.get_success(
- self.handler._map_userinfo_to_user(
- userinfo, token, "user-agent", "10.10.10.10"
- )
+ self._make_callback_with_userinfo(userinfo)
+ auth_handler.complete_sso_login.assert_called_once_with(
+ user.to_string(), ANY, ANY, {},
)
- self.assertEqual(mxid, "@test_user:test")
+ auth_handler.complete_sso_login.reset_mock()
# Register some non-exact matching cases.
user2 = UserID.from_string("@TEST_user_2:test")
@@ -755,14 +705,11 @@ class OidcHandlerTestCase(HomeserverTestCase):
"sub": "test2",
"username": "TEST_USER_2",
}
- e = self.get_failure(
- self.handler._map_userinfo_to_user(
- userinfo, token, "user-agent", "10.10.10.10"
- ),
- MappingException,
- )
+ self._make_callback_with_userinfo(userinfo)
+ auth_handler.complete_sso_login.assert_not_called()
+ args = self.assertRenderedError("mapping_error")
self.assertTrue(
- str(e.value).startswith(
+ args[2].startswith(
"Attempted to login as '@TEST_USER_2:test' but it matches more than one user inexactly:"
)
)
@@ -773,28 +720,15 @@ class OidcHandlerTestCase(HomeserverTestCase):
store.register_user(user_id=user2.to_string(), password_hash=None)
)
- mxid = self.get_success(
- self.handler._map_userinfo_to_user(
- userinfo, token, "user-agent", "10.10.10.10"
- )
+ self._make_callback_with_userinfo(userinfo)
+ auth_handler.complete_sso_login.assert_called_once_with(
+ "@TEST_USER_2:test", ANY, ANY, {},
)
- self.assertEqual(mxid, "@TEST_USER_2:test")
def test_map_userinfo_to_invalid_localpart(self):
"""If the mapping provider generates an invalid localpart it should be rejected."""
- userinfo = {
- "sub": "test2",
- "username": "föö",
- }
- token = {}
-
- e = self.get_failure(
- self.handler._map_userinfo_to_user(
- userinfo, token, "user-agent", "10.10.10.10"
- ),
- MappingException,
- )
- self.assertEqual(str(e.value), "localpart is invalid: föö")
+ self._make_callback_with_userinfo({"sub": "test2", "username": "föö"})
+ self.assertRenderedError("mapping_error", "localpart is invalid: föö")
@override_config(
{
@@ -807,6 +741,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
)
def test_map_userinfo_to_user_retries(self):
"""The mapping provider can retry generating an MXID if the MXID is already in use."""
+ auth_handler = self.hs.get_auth_handler()
+ auth_handler.complete_sso_login = simple_async_mock()
+
store = self.hs.get_datastore()
self.get_success(
store.register_user(user_id="@test_user:test", password_hash=None)
@@ -815,14 +752,13 @@ class OidcHandlerTestCase(HomeserverTestCase):
"sub": "test",
"username": "test_user",
}
- token = {}
- mxid = self.get_success(
- self.handler._map_userinfo_to_user(
- userinfo, token, "user-agent", "10.10.10.10"
- )
- )
+ self._make_callback_with_userinfo(userinfo)
+
# test_user is already taken, so test_user1 gets registered instead.
- self.assertEqual(mxid, "@test_user1:test")
+ auth_handler.complete_sso_login.assert_called_once_with(
+ "@test_user1:test", ANY, ANY, {},
+ )
+ auth_handler.complete_sso_login.reset_mock()
# Register all of the potential mxids for a particular OIDC username.
self.get_success(
@@ -838,12 +774,70 @@ class OidcHandlerTestCase(HomeserverTestCase):
"sub": "tester",
"username": "tester",
}
- e = self.get_failure(
- self.handler._map_userinfo_to_user(
- userinfo, token, "user-agent", "10.10.10.10"
- ),
- MappingException,
+ self._make_callback_with_userinfo(userinfo)
+ auth_handler.complete_sso_login.assert_not_called()
+ self.assertRenderedError(
+ "mapping_error", "Unable to generate a Matrix ID from the SSO response"
+ )
+
+ def _make_callback_with_userinfo(
+ self, userinfo: dict, client_redirect_url: str = "http://client/redirect"
+ ) -> None:
+ self.handler._exchange_code = simple_async_mock(return_value={})
+ self.handler._parse_id_token = simple_async_mock(return_value=userinfo)
+ self.handler._fetch_userinfo = simple_async_mock(return_value=userinfo)
+ auth_handler = self.hs.get_auth_handler()
+ auth_handler.complete_sso_login = simple_async_mock()
+
+ state = "state"
+ session = self.handler._generate_oidc_session_token(
+ state=state,
+ nonce="nonce",
+ client_redirect_url=client_redirect_url,
+ ui_auth_session_id=None,
)
- self.assertEqual(
- str(e.value), "Unable to generate a Matrix ID from the SSO response"
+ request = self._build_callback_request("code", state, session)
+
+ self.get_success(self.handler.handle_oidc_callback(request))
+
+ def _build_callback_request(
+ self,
+ code: str,
+ state: str,
+ session: str,
+ user_agent: str = "Browser",
+ ip_address: str = "10.0.0.1",
+ ):
+ """Builds a fake SynapseRequest to mock the browser callback
+
+ Returns a Mock object which looks like the SynapseRequest we get from a browser
+ after SSO (before we return to the client)
+
+ Args:
+ code: the authorization code which would have been returned by the OIDC
+ provider
+ state: the "state" param which would have been passed around in the
+ query param. Should be the same as was embedded in the session in
+ _build_oidc_session.
+ session: the "session" which would have been passed around in the cookie.
+ user_agent: the user-agent to present
+ ip_address: the IP address to pretend the request came from
+ """
+ request = Mock(
+ spec=[
+ "args",
+ "getCookie",
+ "addCookie",
+ "requestHeaders",
+ "getClientIP",
+ "get_user_agent",
+ ]
)
+
+ request.getCookie.return_value = session
+ request.args = {}
+ request.args[b"code"] = [code.encode("utf-8")]
+ request.args[b"state"] = [state.encode("utf-8")]
+ request.getClientIP.return_value = ip_address
+ request.get_user_agent.return_value = user_agent
+ return request
diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py
index ceaf0902d2..8d50265145 100644
--- a/tests/handlers/test_password_providers.py
+++ b/tests/handlers/test_password_providers.py
@@ -432,6 +432,29 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
@override_config(
{
+ **providers_config(CustomAuthProvider),
+ "password_config": {"enabled": False, "localdb_enabled": False},
+ }
+ )
+ def test_custom_auth_password_disabled_localdb_enabled(self):
+ """Check the localdb_enabled == enabled == False
+
+ Regression test for https://github.com/matrix-org/synapse/issues/8914: check
+ that setting *both* `localdb_enabled` *and* `password: enabled` to False doesn't
+ cause an exception.
+ """
+ self.register_user("localuser", "localpass")
+
+ flows = self._get_login_flows()
+ self.assertEqual(flows, [{"type": "test.login_type"}] + ADDITIONAL_LOGIN_FLOWS)
+
+ # login shouldn't work and should be rejected with a 400 ("unknown login type")
+ channel = self._send_password_login("localuser", "localpass")
+ self.assertEqual(channel.code, 400, channel.result)
+ mock_password_provider.check_auth.assert_not_called()
+
+ @override_config(
+ {
**providers_config(PasswordCustomAuthProvider),
"password_config": {"enabled": False},
}
diff --git a/tests/handlers/test_saml.py b/tests/handlers/test_saml.py
index d21e5588ca..69927cf6be 100644
--- a/tests/handlers/test_saml.py
+++ b/tests/handlers/test_saml.py
@@ -12,11 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Optional
+
+from mock import Mock
+
import attr
from synapse.api.errors import RedirectException
-from synapse.handlers.sso import MappingException
+from tests.test_utils import simple_async_mock
from tests.unittest import HomeserverTestCase, override_config
# Check if we have the dependencies to run the tests.
@@ -44,6 +48,8 @@ BASE_URL = "https://synapse/"
@attr.s
class FakeAuthnResponse:
ava = attr.ib(type=dict)
+ assertions = attr.ib(type=list, factory=list)
+ in_response_to = attr.ib(type=Optional[str], default=None)
class TestMappingProvider:
@@ -111,15 +117,22 @@ class SamlHandlerTestCase(HomeserverTestCase):
def test_map_saml_response_to_user(self):
"""Ensure that mapping the SAML response returned from a provider to an MXID works properly."""
+
+ # stub out the auth handler
+ auth_handler = self.hs.get_auth_handler()
+ auth_handler.complete_sso_login = simple_async_mock()
+
+ # send a mocked-up SAML response to the callback
saml_response = FakeAuthnResponse({"uid": "test_user", "username": "test_user"})
- # The redirect_url doesn't matter with the default user mapping provider.
- redirect_url = ""
- mxid = self.get_success(
- self.handler._map_saml_response_to_user(
- saml_response, redirect_url, "user-agent", "10.10.10.10"
- )
+ request = _mock_request()
+ self.get_success(
+ self.handler._handle_authn_response(request, saml_response, "redirect_uri")
+ )
+
+ # check that the auth handler got called as expected
+ auth_handler.complete_sso_login.assert_called_once_with(
+ "@test_user:test", request, "redirect_uri"
)
- self.assertEqual(mxid, "@test_user:test")
@override_config({"saml2_config": {"grandfathered_mxid_source_attribute": "mxid"}})
def test_map_saml_response_to_existing_user(self):
@@ -129,53 +142,81 @@ class SamlHandlerTestCase(HomeserverTestCase):
store.register_user(user_id="@test_user:test", password_hash=None)
)
+ # stub out the auth handler
+ auth_handler = self.hs.get_auth_handler()
+ auth_handler.complete_sso_login = simple_async_mock()
+
# Map a user via SSO.
saml_response = FakeAuthnResponse(
{"uid": "tester", "mxid": ["test_user"], "username": "test_user"}
)
- redirect_url = ""
- mxid = self.get_success(
- self.handler._map_saml_response_to_user(
- saml_response, redirect_url, "user-agent", "10.10.10.10"
- )
+ request = _mock_request()
+ self.get_success(
+ self.handler._handle_authn_response(request, saml_response, "")
+ )
+
+ # check that the auth handler got called as expected
+ auth_handler.complete_sso_login.assert_called_once_with(
+ "@test_user:test", request, ""
)
- self.assertEqual(mxid, "@test_user:test")
# Subsequent calls should map to the same mxid.
- mxid = self.get_success(
- self.handler._map_saml_response_to_user(
- saml_response, redirect_url, "user-agent", "10.10.10.10"
- )
+ auth_handler.complete_sso_login.reset_mock()
+ self.get_success(
+ self.handler._handle_authn_response(request, saml_response, "")
+ )
+ auth_handler.complete_sso_login.assert_called_once_with(
+ "@test_user:test", request, ""
)
- self.assertEqual(mxid, "@test_user:test")
def test_map_saml_response_to_invalid_localpart(self):
"""If the mapping provider generates an invalid localpart it should be rejected."""
+
+ # stub out the auth handler
+ auth_handler = self.hs.get_auth_handler()
+ auth_handler.complete_sso_login = simple_async_mock()
+
+ # mock out the error renderer too
+ sso_handler = self.hs.get_sso_handler()
+ sso_handler.render_error = Mock(return_value=None)
+
saml_response = FakeAuthnResponse({"uid": "test", "username": "föö"})
- redirect_url = ""
- e = self.get_failure(
- self.handler._map_saml_response_to_user(
- saml_response, redirect_url, "user-agent", "10.10.10.10"
- ),
- MappingException,
+ request = _mock_request()
+ self.get_success(
+ self.handler._handle_authn_response(request, saml_response, ""),
+ )
+ sso_handler.render_error.assert_called_once_with(
+ request, "mapping_error", "localpart is invalid: föö"
)
- self.assertEqual(str(e.value), "localpart is invalid: föö")
+ auth_handler.complete_sso_login.assert_not_called()
def test_map_saml_response_to_user_retries(self):
"""The mapping provider can retry generating an MXID if the MXID is already in use."""
+
+ # stub out the auth handler and error renderer
+ auth_handler = self.hs.get_auth_handler()
+ auth_handler.complete_sso_login = simple_async_mock()
+ sso_handler = self.hs.get_sso_handler()
+ sso_handler.render_error = Mock(return_value=None)
+
+ # register a user to occupy the first-choice MXID
store = self.hs.get_datastore()
self.get_success(
store.register_user(user_id="@test_user:test", password_hash=None)
)
+
+ # send the fake SAML response
saml_response = FakeAuthnResponse({"uid": "test", "username": "test_user"})
- redirect_url = ""
- mxid = self.get_success(
- self.handler._map_saml_response_to_user(
- saml_response, redirect_url, "user-agent", "10.10.10.10"
- )
+ request = _mock_request()
+ self.get_success(
+ self.handler._handle_authn_response(request, saml_response, ""),
)
+
# test_user is already taken, so test_user1 gets registered instead.
- self.assertEqual(mxid, "@test_user1:test")
+ auth_handler.complete_sso_login.assert_called_once_with(
+ "@test_user1:test", request, ""
+ )
+ auth_handler.complete_sso_login.reset_mock()
# Register all of the potential mxids for a particular SAML username.
self.get_success(
@@ -188,15 +229,15 @@ class SamlHandlerTestCase(HomeserverTestCase):
# Now attempt to map to a username, this will fail since all potential usernames are taken.
saml_response = FakeAuthnResponse({"uid": "tester", "username": "tester"})
- e = self.get_failure(
- self.handler._map_saml_response_to_user(
- saml_response, redirect_url, "user-agent", "10.10.10.10"
- ),
- MappingException,
+ self.get_success(
+ self.handler._handle_authn_response(request, saml_response, ""),
)
- self.assertEqual(
- str(e.value), "Unable to generate a Matrix ID from the SSO response"
+ sso_handler.render_error.assert_called_once_with(
+ request,
+ "mapping_error",
+ "Unable to generate a Matrix ID from the SSO response",
)
+ auth_handler.complete_sso_login.assert_not_called()
@override_config(
{
@@ -208,12 +249,17 @@ class SamlHandlerTestCase(HomeserverTestCase):
}
)
def test_map_saml_response_redirect(self):
+ """Test a mapping provider that raises a RedirectException"""
+
saml_response = FakeAuthnResponse({"uid": "test", "username": "test_user"})
- redirect_url = ""
+ request = _mock_request()
e = self.get_failure(
- self.handler._map_saml_response_to_user(
- saml_response, redirect_url, "user-agent", "10.10.10.10"
- ),
+ self.handler._handle_authn_response(request, saml_response, ""),
RedirectException,
)
self.assertEqual(e.value.location, b"https://custom-saml-redirect/")
+
+
+def _mock_request():
+ """Returns a mock which will stand in as a SynapseRequest"""
+ return Mock(spec=["getClientIP", "get_user_agent"])
diff --git a/tests/logging/test_terse_json.py b/tests/logging/test_terse_json.py
index f6e7e5fdaa..48a74e2eee 100644
--- a/tests/logging/test_terse_json.py
+++ b/tests/logging/test_terse_json.py
@@ -117,11 +117,10 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase):
"""
handler = logging.StreamHandler(self.output)
handler.setFormatter(JsonFormatter())
- handler.addFilter(LoggingContextFilter(request=""))
+ handler.addFilter(LoggingContextFilter())
logger = self.get_logger(handler)
- with LoggingContext() as context_one:
- context_one.request = "test"
+ with LoggingContext(request="test"):
logger.info("Hello there, %s!", "wally")
log = self.get_log_line()
@@ -132,9 +131,7 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase):
"level",
"namespace",
"request",
- "scope",
]
self.assertCountEqual(log.keys(), expected_log_keys)
self.assertEqual(log["log"], "Hello there, wally!")
self.assertEqual(log["request"], "test")
- self.assertIsNone(log["scope"])
diff --git a/tests/test_federation.py b/tests/test_federation.py
index fa45f8b3b7..fc9aab32d0 100644
--- a/tests/test_federation.py
+++ b/tests/test_federation.py
@@ -134,7 +134,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
}
)
- with LoggingContext(request="lying_event"):
+ with LoggingContext():
failure = self.get_failure(
self.handler.on_receive_pdu(
"test.serv", lying_event, sent_to_us_directly=True
diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py
index 6873d45eb6..43898d8142 100644
--- a/tests/test_utils/__init__.py
+++ b/tests/test_utils/__init__.py
@@ -22,6 +22,8 @@ import warnings
from asyncio import Future
from typing import Any, Awaitable, Callable, TypeVar
+from mock import Mock
+
import attr
from twisted.python.failure import Failure
@@ -87,6 +89,16 @@ def setup_awaitable_errors() -> Callable[[], None]:
return cleanup
+def simple_async_mock(return_value=None, raises=None) -> Mock:
+ # AsyncMock is not available in python3.5, this mimics part of its behaviour
+ async def cb(*args, **kwargs):
+ if raises:
+ raise raises
+ return return_value
+
+ return Mock(side_effect=cb)
+
+
@attr.s
class FakeResponse:
"""A fake twisted.web.IResponse object
diff --git a/tests/test_utils/logging_setup.py b/tests/test_utils/logging_setup.py
index fdfb840b62..52ae5c5713 100644
--- a/tests/test_utils/logging_setup.py
+++ b/tests/test_utils/logging_setup.py
@@ -48,7 +48,7 @@ def setup_logging():
handler = ToTwistedHandler()
formatter = logging.Formatter(log_format)
handler.setFormatter(formatter)
- handler.addFilter(LoggingContextFilter(request=""))
+ handler.addFilter(LoggingContextFilter())
root_logger.addHandler(handler)
log_level = os.environ.get("SYNAPSE_TEST_LOG_LEVEL", "ERROR")
|