diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index 464e569ac8..c54f1c5797 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -19,8 +19,9 @@ from mock import ANY, Mock, patch
import pymacaroons
-from synapse.handlers.oidc_handler import OidcError, OidcMappingProvider
+from synapse.handlers.oidc_handler import OidcError
from synapse.handlers.sso import MappingException
+from synapse.server import HomeServer
from synapse.types import UserID
from tests.test_utils import FakeResponse, simple_async_mock
@@ -55,11 +56,14 @@ COOKIE_NAME = b"oidc_session"
COOKIE_PATH = "/_synapse/oidc"
-class TestMappingProvider(OidcMappingProvider):
+class TestMappingProvider:
@staticmethod
def parse_config(config):
return
+ def __init__(self, config):
+ pass
+
def get_remote_user_id(self, userinfo):
return userinfo["sub"]
@@ -360,6 +364,13 @@ class OidcHandlerTestCase(HomeserverTestCase):
- when the userinfo fetching fails
- when the code exchange fails
"""
+
+ # ensure that we are correctly testing the fallback when "get_extra_attributes"
+ # is not implemented.
+ mapping_provider = self.handler._user_mapping_provider
+ with self.assertRaises(AttributeError):
+ _ = mapping_provider.get_extra_attributes
+
token = {
"type": "bearer",
"id_token": "id_token",
@@ -389,14 +400,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
client_redirect_url=client_redirect_url,
ui_auth_session_id=None,
)
- request = self._build_callback_request(
+ request = _build_callback_request(
code, state, session, user_agent=user_agent, ip_address=ip_address
)
self.get_success(self.handler.handle_oidc_callback(request))
auth_handler.complete_sso_login.assert_called_once_with(
- expected_user_id, request, client_redirect_url, {},
+ expected_user_id, request, client_redirect_url, None,
)
self.handler._exchange_code.assert_called_once_with(code)
self.handler._parse_id_token.assert_called_once_with(token, nonce=nonce)
@@ -427,7 +438,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.get_success(self.handler.handle_oidc_callback(request))
auth_handler.complete_sso_login.assert_called_once_with(
- expected_user_id, request, client_redirect_url, {},
+ expected_user_id, request, client_redirect_url, None,
)
self.handler._exchange_code.assert_called_once_with(code)
self.handler._parse_id_token.assert_not_called()
@@ -597,7 +608,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
client_redirect_url=client_redirect_url,
ui_auth_session_id=None,
)
- request = self._build_callback_request("code", state, session)
+ request = _build_callback_request("code", state, session)
self.get_success(self.handler.handle_oidc_callback(request))
@@ -614,9 +625,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
"sub": "test_user",
"username": "test_user",
}
- self._make_callback_with_userinfo(userinfo)
+ self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
auth_handler.complete_sso_login.assert_called_once_with(
- "@test_user:test", ANY, ANY, {}
+ "@test_user:test", ANY, ANY, None,
)
auth_handler.complete_sso_login.reset_mock()
@@ -625,9 +636,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
"sub": 1234,
"username": "test_user_2",
}
- self._make_callback_with_userinfo(userinfo)
+ self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
auth_handler.complete_sso_login.assert_called_once_with(
- "@test_user_2:test", ANY, ANY, {}
+ "@test_user_2:test", ANY, ANY, None,
)
auth_handler.complete_sso_login.reset_mock()
@@ -638,7 +649,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
store.register_user(user_id=user3.to_string(), password_hash=None)
)
userinfo = {"sub": "test3", "username": "test_user_3"}
- self._make_callback_with_userinfo(userinfo)
+ self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
auth_handler.complete_sso_login.assert_not_called()
self.assertRenderedError(
"mapping_error",
@@ -662,16 +673,16 @@ class OidcHandlerTestCase(HomeserverTestCase):
"sub": "test",
"username": "test_user",
}
- self._make_callback_with_userinfo(userinfo)
+ self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
auth_handler.complete_sso_login.assert_called_once_with(
- user.to_string(), ANY, ANY, {},
+ user.to_string(), ANY, ANY, None,
)
auth_handler.complete_sso_login.reset_mock()
# Subsequent calls should map to the same mxid.
- self._make_callback_with_userinfo(userinfo)
+ self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
auth_handler.complete_sso_login.assert_called_once_with(
- user.to_string(), ANY, ANY, {},
+ user.to_string(), ANY, ANY, None,
)
auth_handler.complete_sso_login.reset_mock()
@@ -684,9 +695,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
"sub": "test1",
"username": "test_user",
}
- self._make_callback_with_userinfo(userinfo)
+ self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
auth_handler.complete_sso_login.assert_called_once_with(
- user.to_string(), ANY, ANY, {},
+ user.to_string(), ANY, ANY, None,
)
auth_handler.complete_sso_login.reset_mock()
@@ -705,7 +716,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
"sub": "test2",
"username": "TEST_USER_2",
}
- self._make_callback_with_userinfo(userinfo)
+ self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
auth_handler.complete_sso_login.assert_not_called()
args = self.assertRenderedError("mapping_error")
self.assertTrue(
@@ -720,14 +731,16 @@ class OidcHandlerTestCase(HomeserverTestCase):
store.register_user(user_id=user2.to_string(), password_hash=None)
)
- self._make_callback_with_userinfo(userinfo)
+ self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
auth_handler.complete_sso_login.assert_called_once_with(
- "@TEST_USER_2:test", ANY, ANY, {},
+ "@TEST_USER_2:test", ANY, ANY, None,
)
def test_map_userinfo_to_invalid_localpart(self):
"""If the mapping provider generates an invalid localpart it should be rejected."""
- self._make_callback_with_userinfo({"sub": "test2", "username": "föö"})
+ self.get_success(
+ _make_callback_with_userinfo(self.hs, {"sub": "test2", "username": "föö"})
+ )
self.assertRenderedError("mapping_error", "localpart is invalid: föö")
@override_config(
@@ -752,11 +765,11 @@ class OidcHandlerTestCase(HomeserverTestCase):
"sub": "test",
"username": "test_user",
}
- self._make_callback_with_userinfo(userinfo)
+ self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
# test_user is already taken, so test_user1 gets registered instead.
auth_handler.complete_sso_login.assert_called_once_with(
- "@test_user1:test", ANY, ANY, {},
+ "@test_user1:test", ANY, ANY, None,
)
auth_handler.complete_sso_login.reset_mock()
@@ -774,70 +787,80 @@ class OidcHandlerTestCase(HomeserverTestCase):
"sub": "tester",
"username": "tester",
}
- self._make_callback_with_userinfo(userinfo)
+ self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
auth_handler.complete_sso_login.assert_not_called()
self.assertRenderedError(
"mapping_error", "Unable to generate a Matrix ID from the SSO response"
)
- def _make_callback_with_userinfo(
- self, userinfo: dict, client_redirect_url: str = "http://client/redirect"
- ) -> None:
- self.handler._exchange_code = simple_async_mock(return_value={})
- self.handler._parse_id_token = simple_async_mock(return_value=userinfo)
- self.handler._fetch_userinfo = simple_async_mock(return_value=userinfo)
- auth_handler = self.hs.get_auth_handler()
- auth_handler.complete_sso_login = simple_async_mock()
- state = "state"
- session = self.handler._generate_oidc_session_token(
- state=state,
- nonce="nonce",
- client_redirect_url=client_redirect_url,
- ui_auth_session_id=None,
- )
- request = self._build_callback_request("code", state, session)
+async def _make_callback_with_userinfo(
+ hs: HomeServer, userinfo: dict, client_redirect_url: str = "http://client/redirect"
+) -> None:
+ """Mock up an OIDC callback with the given userinfo dict
- self.get_success(self.handler.handle_oidc_callback(request))
+ We'll pull out the OIDC handler from the homeserver, stub out a couple of methods,
+ and poke in the userinfo dict as if it were the response to an OIDC userinfo call.
- def _build_callback_request(
- self,
- code: str,
- state: str,
- session: str,
- user_agent: str = "Browser",
- ip_address: str = "10.0.0.1",
- ):
- """Builds a fake SynapseRequest to mock the browser callback
-
- Returns a Mock object which looks like the SynapseRequest we get from a browser
- after SSO (before we return to the client)
-
- Args:
- code: the authorization code which would have been returned by the OIDC
- provider
- state: the "state" param which would have been passed around in the
- query param. Should be the same as was embedded in the session in
- _build_oidc_session.
- session: the "session" which would have been passed around in the cookie.
- user_agent: the user-agent to present
- ip_address: the IP address to pretend the request came from
- """
- request = Mock(
- spec=[
- "args",
- "getCookie",
- "addCookie",
- "requestHeaders",
- "getClientIP",
- "get_user_agent",
- ]
- )
+ Args:
+ hs: the HomeServer impl to send the callback to.
+ userinfo: the OIDC userinfo dict
+ client_redirect_url: the URL to redirect to on success.
+ """
+ handler = hs.get_oidc_handler()
+ handler._exchange_code = simple_async_mock(return_value={})
+ handler._parse_id_token = simple_async_mock(return_value=userinfo)
+ handler._fetch_userinfo = simple_async_mock(return_value=userinfo)
- request.getCookie.return_value = session
- request.args = {}
- request.args[b"code"] = [code.encode("utf-8")]
- request.args[b"state"] = [state.encode("utf-8")]
- request.getClientIP.return_value = ip_address
- request.get_user_agent.return_value = user_agent
- return request
+ state = "state"
+ session = handler._generate_oidc_session_token(
+ state=state,
+ nonce="nonce",
+ client_redirect_url=client_redirect_url,
+ ui_auth_session_id=None,
+ )
+ request = _build_callback_request("code", state, session)
+
+ await handler.handle_oidc_callback(request)
+
+
+def _build_callback_request(
+ code: str,
+ state: str,
+ session: str,
+ user_agent: str = "Browser",
+ ip_address: str = "10.0.0.1",
+):
+ """Builds a fake SynapseRequest to mock the browser callback
+
+ Returns a Mock object which looks like the SynapseRequest we get from a browser
+ after SSO (before we return to the client)
+
+ Args:
+ code: the authorization code which would have been returned by the OIDC
+ provider
+ state: the "state" param which would have been passed around in the
+ query param. Should be the same as was embedded in the session in
+ _build_oidc_session.
+ session: the "session" which would have been passed around in the cookie.
+ user_agent: the user-agent to present
+ ip_address: the IP address to pretend the request came from
+ """
+ request = Mock(
+ spec=[
+ "args",
+ "getCookie",
+ "addCookie",
+ "requestHeaders",
+ "getClientIP",
+ "get_user_agent",
+ ]
+ )
+
+ request.getCookie.return_value = session
+ request.args = {}
+ request.args[b"code"] = [code.encode("utf-8")]
+ request.args[b"state"] = [state.encode("utf-8")]
+ request.getClientIP.return_value = ip_address
+ request.get_user_agent.return_value = user_agent
+ return request
|