From 76469898ee797db232adaccb9fd547bddab2fe59 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Wed, 2 Dec 2020 18:22:01 +0000 Subject: Factor out FakeResponse from test_oidc --- tests/test_utils/__init__.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) (limited to 'tests/test_utils') diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py index d232b72264..6873d45eb6 100644 --- a/tests/test_utils/__init__.py +++ b/tests/test_utils/__init__.py @@ -22,6 +22,11 @@ import warnings from asyncio import Future from typing import Any, Awaitable, Callable, TypeVar +import attr + +from twisted.python.failure import Failure +from twisted.web.client import ResponseDone + TV = TypeVar("TV") @@ -80,3 +85,25 @@ def setup_awaitable_errors() -> Callable[[], None]: sys.unraisablehook = unraisablehook # type: ignore return cleanup + + +@attr.s +class FakeResponse: + """A fake twisted.web.IResponse object + + there is a similar class at treq.test.test_response, but it lacks a `phrase` + attribute, and didn't support deliverBody until recently. + """ + + # HTTP response code + code = attr.ib(type=int) + + # HTTP response phrase (eg b'OK' for a 200) + phrase = attr.ib(type=bytes) + + # body of the response + body = attr.ib(type=bytes) + + def deliverBody(self, protocol): + protocol.dataReceived(self.body) + protocol.connectionLost(Failure(ResponseDone())) -- cgit 1.5.1 From 1619802228033455ff6e5863c52556996b38e8c6 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 14 Dec 2020 14:19:47 -0500 Subject: Various clean-ups to the logging context code (#8935) --- changelog.d/8916.misc | 2 +- changelog.d/8935.misc | 1 + synapse/config/logger.py | 2 +- synapse/http/site.py | 3 +-- synapse/logging/context.py | 24 +++++------------------- synapse/metrics/background_process_metrics.py | 7 +++---- synapse/replication/tcp/protocol.py | 3 +-- tests/handlers/test_federation.py | 6 +++--- tests/logging/test_terse_json.py | 7 ++----- tests/test_federation.py | 2 +- tests/test_utils/logging_setup.py | 2 +- 11 files changed, 20 insertions(+), 39 deletions(-) create mode 100644 changelog.d/8935.misc (limited to 'tests/test_utils') diff --git a/changelog.d/8916.misc b/changelog.d/8916.misc index c71ef480e6..bf94135fd5 100644 --- a/changelog.d/8916.misc +++ b/changelog.d/8916.misc @@ -1 +1 @@ -Improve structured logging tests. +Various clean-ups to the structured logging and logging context code. diff --git a/changelog.d/8935.misc b/changelog.d/8935.misc new file mode 100644 index 0000000000..bf94135fd5 --- /dev/null +++ b/changelog.d/8935.misc @@ -0,0 +1 @@ +Various clean-ups to the structured logging and logging context code. diff --git a/synapse/config/logger.py b/synapse/config/logger.py index d4e887a3e0..4df3f93c1c 100644 --- a/synapse/config/logger.py +++ b/synapse/config/logger.py @@ -206,7 +206,7 @@ def _setup_stdlib_logging(config, log_config_path, logBeginner: LogBeginner) -> # filter options, but care must when using e.g. MemoryHandler to buffer # writes. - log_context_filter = LoggingContextFilter(request="") + log_context_filter = LoggingContextFilter() log_metadata_filter = MetadataFilter({"server_name": config.server_name}) old_factory = logging.getLogRecordFactory() diff --git a/synapse/http/site.py b/synapse/http/site.py index 5f0581dc3f..5a5790831b 100644 --- a/synapse/http/site.py +++ b/synapse/http/site.py @@ -128,8 +128,7 @@ class SynapseRequest(Request): # create a LogContext for this request request_id = self.get_request_id() - logcontext = self.logcontext = LoggingContext(request_id) - logcontext.request = request_id + self.logcontext = LoggingContext(request_id, request=request_id) # override the Server header which is set by twisted self.setHeader("Server", self.site.server_version_string) diff --git a/synapse/logging/context.py b/synapse/logging/context.py index ca0c774cc5..a507a83e93 100644 --- a/synapse/logging/context.py +++ b/synapse/logging/context.py @@ -203,10 +203,6 @@ class _Sentinel: def copy_to(self, record): pass - def copy_to_twisted_log_entry(self, record): - record["request"] = None - record["scope"] = None - def start(self, rusage: "Optional[resource._RUsage]"): pass @@ -372,13 +368,6 @@ class LoggingContext: # we also track the current scope: record.scope = self.scope - def copy_to_twisted_log_entry(self, record) -> None: - """ - Copy logging fields from this context to a Twisted log record. - """ - record["request"] = self.request - record["scope"] = self.scope - def start(self, rusage: "Optional[resource._RUsage]") -> None: """ Record that this logcontext is currently running. @@ -542,13 +531,10 @@ class LoggingContext: class LoggingContextFilter(logging.Filter): """Logging filter that adds values from the current logging context to each record. - Args: - **defaults: Default values to avoid formatters complaining about - missing fields """ - def __init__(self, **defaults) -> None: - self.defaults = defaults + def __init__(self, request: str = ""): + self._default_request = request def filter(self, record) -> Literal[True]: """Add each fields from the logging contexts to the record. @@ -556,14 +542,14 @@ class LoggingContextFilter(logging.Filter): True to include the record in the log output. """ context = current_context() - for key, value in self.defaults.items(): - setattr(record, key, value) + record.request = self._default_request # context should never be None, but if it somehow ends up being, then # we end up in a death spiral of infinite loops, so let's check, for # robustness' sake. if context is not None: - context.copy_to(record) + # Logging is interested in the request. + record.request = context.request return True diff --git a/synapse/metrics/background_process_metrics.py b/synapse/metrics/background_process_metrics.py index 76b7decf26..70e0fa45d9 100644 --- a/synapse/metrics/background_process_metrics.py +++ b/synapse/metrics/background_process_metrics.py @@ -199,8 +199,7 @@ def run_as_background_process(desc: str, func, *args, bg_start_span=True, **kwar _background_process_start_count.labels(desc).inc() _background_process_in_flight_count.labels(desc).inc() - with BackgroundProcessLoggingContext(desc) as context: - context.request = "%s-%i" % (desc, count) + with BackgroundProcessLoggingContext(desc, "%s-%i" % (desc, count)) as context: try: ctx = noop_context_manager() if bg_start_span: @@ -244,8 +243,8 @@ class BackgroundProcessLoggingContext(LoggingContext): __slots__ = ["_proc"] - def __init__(self, name: str): - super().__init__(name) + def __init__(self, name: str, request: Optional[str] = None): + super().__init__(name, request=request) self._proc = _BackgroundProcess(name, self) diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py index a509e599c2..804da994ea 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py @@ -172,8 +172,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): # a logcontext which we use for processing incoming commands. We declare it as a # background process so that the CPU stats get reported to prometheus. ctx_name = "replication-conn-%s" % self.conn_id - self._logging_context = BackgroundProcessLoggingContext(ctx_name) - self._logging_context.request = ctx_name + self._logging_context = BackgroundProcessLoggingContext(ctx_name, ctx_name) def connectionMade(self): logger.info("[%s] Connection established", self.id()) diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py index d0452e1490..0b24b89a2e 100644 --- a/tests/handlers/test_federation.py +++ b/tests/handlers/test_federation.py @@ -126,7 +126,7 @@ class FederationTestCase(unittest.HomeserverTestCase): room_version, ) - with LoggingContext(request="send_rejected"): + with LoggingContext("send_rejected"): d = run_in_background(self.handler.on_receive_pdu, OTHER_SERVER, ev) self.get_success(d) @@ -178,7 +178,7 @@ class FederationTestCase(unittest.HomeserverTestCase): room_version, ) - with LoggingContext(request="send_rejected"): + with LoggingContext("send_rejected"): d = run_in_background(self.handler.on_receive_pdu, OTHER_SERVER, ev) self.get_success(d) @@ -198,7 +198,7 @@ class FederationTestCase(unittest.HomeserverTestCase): # the auth code requires that a signature exists, but doesn't check that # signature... go figure. join_event.signatures[other_server] = {"x": "y"} - with LoggingContext(request="send_join"): + with LoggingContext("send_join"): d = run_in_background( self.handler.on_send_join_request, other_server, join_event ) diff --git a/tests/logging/test_terse_json.py b/tests/logging/test_terse_json.py index f6e7e5fdaa..48a74e2eee 100644 --- a/tests/logging/test_terse_json.py +++ b/tests/logging/test_terse_json.py @@ -117,11 +117,10 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase): """ handler = logging.StreamHandler(self.output) handler.setFormatter(JsonFormatter()) - handler.addFilter(LoggingContextFilter(request="")) + handler.addFilter(LoggingContextFilter()) logger = self.get_logger(handler) - with LoggingContext() as context_one: - context_one.request = "test" + with LoggingContext(request="test"): logger.info("Hello there, %s!", "wally") log = self.get_log_line() @@ -132,9 +131,7 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase): "level", "namespace", "request", - "scope", ] self.assertCountEqual(log.keys(), expected_log_keys) self.assertEqual(log["log"], "Hello there, wally!") self.assertEqual(log["request"], "test") - self.assertIsNone(log["scope"]) diff --git a/tests/test_federation.py b/tests/test_federation.py index fa45f8b3b7..fc9aab32d0 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -134,7 +134,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase): } ) - with LoggingContext(request="lying_event"): + with LoggingContext(): failure = self.get_failure( self.handler.on_receive_pdu( "test.serv", lying_event, sent_to_us_directly=True diff --git a/tests/test_utils/logging_setup.py b/tests/test_utils/logging_setup.py index fdfb840b62..52ae5c5713 100644 --- a/tests/test_utils/logging_setup.py +++ b/tests/test_utils/logging_setup.py @@ -48,7 +48,7 @@ def setup_logging(): handler = ToTwistedHandler() formatter = logging.Formatter(log_format) handler.setFormatter(formatter) - handler.addFilter(LoggingContextFilter(request="")) + handler.addFilter(LoggingContextFilter()) root_logger.addHandler(handler) log_level = os.environ.get("SYNAPSE_TEST_LOG_LEVEL", "ERROR") -- cgit 1.5.1 From 01333681bc3db22541b49c194f5121a5415731c6 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 15 Dec 2020 20:56:10 +0000 Subject: Preparatory refactoring of the SamlHandlerTestCase (#8938) * move simple_async_mock to test_utils ... so that it can be re-used * Remove references to `SamlHandler._map_saml_response_to_user` from tests This method is going away, so we can no longer use it as a test point. Instead, factor out a higher-level method which takes a SAML object, and verify correct behaviour by mocking out `AuthHandler.complete_sso_login`. * changelog --- changelog.d/8938.feature | 1 + synapse/handlers/saml_handler.py | 23 +++++++ tests/handlers/test_oidc.py | 12 +--- tests/handlers/test_saml.py | 132 ++++++++++++++++++++++++++------------- tests/test_utils/__init__.py | 12 ++++ 5 files changed, 126 insertions(+), 54 deletions(-) create mode 100644 changelog.d/8938.feature (limited to 'tests/test_utils') diff --git a/changelog.d/8938.feature b/changelog.d/8938.feature new file mode 100644 index 0000000000..d450ef4998 --- /dev/null +++ b/changelog.d/8938.feature @@ -0,0 +1 @@ +Add support for allowing users to pick their own user ID during a single-sign-on login. diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py index f2ca1ddb53..6001fe3e27 100644 --- a/synapse/handlers/saml_handler.py +++ b/synapse/handlers/saml_handler.py @@ -163,6 +163,29 @@ class SamlHandler(BaseHandler): return logger.debug("SAML2 response: %s", saml2_auth.origxml) + + await self._handle_authn_response(request, saml2_auth, relay_state) + + async def _handle_authn_response( + self, + request: SynapseRequest, + saml2_auth: saml2.response.AuthnResponse, + relay_state: str, + ) -> None: + """Handle an AuthnResponse, having parsed it from the request params + + Assumes that the signature on the response object has been checked. Maps + the user onto an MXID, registering them if necessary, and returns a response + to the browser. + + Args: + request: the incoming request from the browser. We'll respond to it with an + HTML page or a redirect + saml2_auth: the parsed AuthnResponse object + relay_state: the RelayState query param, which encodes the URI to rediret + back to + """ + for assertion in saml2_auth.assertions: # kibana limits the length of a log field, whereas this is all rather # useful, so split it up. diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py index 9878527bab..464e569ac8 100644 --- a/tests/handlers/test_oidc.py +++ b/tests/handlers/test_oidc.py @@ -23,7 +23,7 @@ from synapse.handlers.oidc_handler import OidcError, OidcMappingProvider from synapse.handlers.sso import MappingException from synapse.types import UserID -from tests.test_utils import FakeResponse +from tests.test_utils import FakeResponse, simple_async_mock from tests.unittest import HomeserverTestCase, override_config # These are a few constants that are used as config parameters in the tests. @@ -82,16 +82,6 @@ class TestMappingProviderFailures(TestMappingProvider): } -def simple_async_mock(return_value=None, raises=None) -> Mock: - # AsyncMock is not available in python3.5, this mimics part of its behaviour - async def cb(*args, **kwargs): - if raises: - raise raises - return return_value - - return Mock(side_effect=cb) - - async def get_json(url): # Mock get_json calls to handle jwks & oidc discovery endpoints if url == WELL_KNOWN: diff --git a/tests/handlers/test_saml.py b/tests/handlers/test_saml.py index d21e5588ca..69927cf6be 100644 --- a/tests/handlers/test_saml.py +++ b/tests/handlers/test_saml.py @@ -12,11 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Optional + +from mock import Mock + import attr from synapse.api.errors import RedirectException -from synapse.handlers.sso import MappingException +from tests.test_utils import simple_async_mock from tests.unittest import HomeserverTestCase, override_config # Check if we have the dependencies to run the tests. @@ -44,6 +48,8 @@ BASE_URL = "https://synapse/" @attr.s class FakeAuthnResponse: ava = attr.ib(type=dict) + assertions = attr.ib(type=list, factory=list) + in_response_to = attr.ib(type=Optional[str], default=None) class TestMappingProvider: @@ -111,15 +117,22 @@ class SamlHandlerTestCase(HomeserverTestCase): def test_map_saml_response_to_user(self): """Ensure that mapping the SAML response returned from a provider to an MXID works properly.""" + + # stub out the auth handler + auth_handler = self.hs.get_auth_handler() + auth_handler.complete_sso_login = simple_async_mock() + + # send a mocked-up SAML response to the callback saml_response = FakeAuthnResponse({"uid": "test_user", "username": "test_user"}) - # The redirect_url doesn't matter with the default user mapping provider. - redirect_url = "" - mxid = self.get_success( - self.handler._map_saml_response_to_user( - saml_response, redirect_url, "user-agent", "10.10.10.10" - ) + request = _mock_request() + self.get_success( + self.handler._handle_authn_response(request, saml_response, "redirect_uri") + ) + + # check that the auth handler got called as expected + auth_handler.complete_sso_login.assert_called_once_with( + "@test_user:test", request, "redirect_uri" ) - self.assertEqual(mxid, "@test_user:test") @override_config({"saml2_config": {"grandfathered_mxid_source_attribute": "mxid"}}) def test_map_saml_response_to_existing_user(self): @@ -129,53 +142,81 @@ class SamlHandlerTestCase(HomeserverTestCase): store.register_user(user_id="@test_user:test", password_hash=None) ) + # stub out the auth handler + auth_handler = self.hs.get_auth_handler() + auth_handler.complete_sso_login = simple_async_mock() + # Map a user via SSO. saml_response = FakeAuthnResponse( {"uid": "tester", "mxid": ["test_user"], "username": "test_user"} ) - redirect_url = "" - mxid = self.get_success( - self.handler._map_saml_response_to_user( - saml_response, redirect_url, "user-agent", "10.10.10.10" - ) + request = _mock_request() + self.get_success( + self.handler._handle_authn_response(request, saml_response, "") + ) + + # check that the auth handler got called as expected + auth_handler.complete_sso_login.assert_called_once_with( + "@test_user:test", request, "" ) - self.assertEqual(mxid, "@test_user:test") # Subsequent calls should map to the same mxid. - mxid = self.get_success( - self.handler._map_saml_response_to_user( - saml_response, redirect_url, "user-agent", "10.10.10.10" - ) + auth_handler.complete_sso_login.reset_mock() + self.get_success( + self.handler._handle_authn_response(request, saml_response, "") + ) + auth_handler.complete_sso_login.assert_called_once_with( + "@test_user:test", request, "" ) - self.assertEqual(mxid, "@test_user:test") def test_map_saml_response_to_invalid_localpart(self): """If the mapping provider generates an invalid localpart it should be rejected.""" + + # stub out the auth handler + auth_handler = self.hs.get_auth_handler() + auth_handler.complete_sso_login = simple_async_mock() + + # mock out the error renderer too + sso_handler = self.hs.get_sso_handler() + sso_handler.render_error = Mock(return_value=None) + saml_response = FakeAuthnResponse({"uid": "test", "username": "föö"}) - redirect_url = "" - e = self.get_failure( - self.handler._map_saml_response_to_user( - saml_response, redirect_url, "user-agent", "10.10.10.10" - ), - MappingException, + request = _mock_request() + self.get_success( + self.handler._handle_authn_response(request, saml_response, ""), + ) + sso_handler.render_error.assert_called_once_with( + request, "mapping_error", "localpart is invalid: föö" ) - self.assertEqual(str(e.value), "localpart is invalid: föö") + auth_handler.complete_sso_login.assert_not_called() def test_map_saml_response_to_user_retries(self): """The mapping provider can retry generating an MXID if the MXID is already in use.""" + + # stub out the auth handler and error renderer + auth_handler = self.hs.get_auth_handler() + auth_handler.complete_sso_login = simple_async_mock() + sso_handler = self.hs.get_sso_handler() + sso_handler.render_error = Mock(return_value=None) + + # register a user to occupy the first-choice MXID store = self.hs.get_datastore() self.get_success( store.register_user(user_id="@test_user:test", password_hash=None) ) + + # send the fake SAML response saml_response = FakeAuthnResponse({"uid": "test", "username": "test_user"}) - redirect_url = "" - mxid = self.get_success( - self.handler._map_saml_response_to_user( - saml_response, redirect_url, "user-agent", "10.10.10.10" - ) + request = _mock_request() + self.get_success( + self.handler._handle_authn_response(request, saml_response, ""), ) + # test_user is already taken, so test_user1 gets registered instead. - self.assertEqual(mxid, "@test_user1:test") + auth_handler.complete_sso_login.assert_called_once_with( + "@test_user1:test", request, "" + ) + auth_handler.complete_sso_login.reset_mock() # Register all of the potential mxids for a particular SAML username. self.get_success( @@ -188,15 +229,15 @@ class SamlHandlerTestCase(HomeserverTestCase): # Now attempt to map to a username, this will fail since all potential usernames are taken. saml_response = FakeAuthnResponse({"uid": "tester", "username": "tester"}) - e = self.get_failure( - self.handler._map_saml_response_to_user( - saml_response, redirect_url, "user-agent", "10.10.10.10" - ), - MappingException, + self.get_success( + self.handler._handle_authn_response(request, saml_response, ""), ) - self.assertEqual( - str(e.value), "Unable to generate a Matrix ID from the SSO response" + sso_handler.render_error.assert_called_once_with( + request, + "mapping_error", + "Unable to generate a Matrix ID from the SSO response", ) + auth_handler.complete_sso_login.assert_not_called() @override_config( { @@ -208,12 +249,17 @@ class SamlHandlerTestCase(HomeserverTestCase): } ) def test_map_saml_response_redirect(self): + """Test a mapping provider that raises a RedirectException""" + saml_response = FakeAuthnResponse({"uid": "test", "username": "test_user"}) - redirect_url = "" + request = _mock_request() e = self.get_failure( - self.handler._map_saml_response_to_user( - saml_response, redirect_url, "user-agent", "10.10.10.10" - ), + self.handler._handle_authn_response(request, saml_response, ""), RedirectException, ) self.assertEqual(e.value.location, b"https://custom-saml-redirect/") + + +def _mock_request(): + """Returns a mock which will stand in as a SynapseRequest""" + return Mock(spec=["getClientIP", "get_user_agent"]) diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py index 6873d45eb6..43898d8142 100644 --- a/tests/test_utils/__init__.py +++ b/tests/test_utils/__init__.py @@ -22,6 +22,8 @@ import warnings from asyncio import Future from typing import Any, Awaitable, Callable, TypeVar +from mock import Mock + import attr from twisted.python.failure import Failure @@ -87,6 +89,16 @@ def setup_awaitable_errors() -> Callable[[], None]: return cleanup +def simple_async_mock(return_value=None, raises=None) -> Mock: + # AsyncMock is not available in python3.5, this mimics part of its behaviour + async def cb(*args, **kwargs): + if raises: + raise raises + return return_value + + return Mock(side_effect=cb) + + @attr.s class FakeResponse: """A fake twisted.web.IResponse object -- cgit 1.5.1 From 02070c69faa47bf6aef280939c2d5f32cbcb9f25 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Mon, 18 Jan 2021 14:52:49 +0000 Subject: Fix bugs in handling clientRedirectUrl, and improve OIDC tests (#9127, #9128) * Factor out a common TestHtmlParser Looks like I'm doing this in a few different places. * Improve OIDC login test Complete the OIDC login flow, rather than giving up halfway through. * Ensure that OIDC login works with multiple OIDC providers * Fix bugs in handling clientRedirectUrl - don't drop duplicate query-params, or params with no value - allow utf-8 in query-params --- changelog.d/9127.feature | 1 + changelog.d/9128.bugfix | 1 + synapse/handlers/auth.py | 4 +- synapse/handlers/oidc_handler.py | 2 +- synapse/rest/synapse/client/pick_idp.py | 4 +- tests/rest/client/v1/test_login.py | 146 ++++++++++++++++++++------------ tests/rest/client/v1/utils.py | 62 ++++++++------ tests/server.py | 2 +- tests/test_utils/html_parsers.py | 53 ++++++++++++ 9 files changed, 189 insertions(+), 86 deletions(-) create mode 100644 changelog.d/9127.feature create mode 100644 changelog.d/9128.bugfix create mode 100644 tests/test_utils/html_parsers.py (limited to 'tests/test_utils') diff --git a/changelog.d/9127.feature b/changelog.d/9127.feature new file mode 100644 index 0000000000..01a24dcf49 --- /dev/null +++ b/changelog.d/9127.feature @@ -0,0 +1 @@ +Add support for multiple SSO Identity Providers. diff --git a/changelog.d/9128.bugfix b/changelog.d/9128.bugfix new file mode 100644 index 0000000000..f87b9fb9aa --- /dev/null +++ b/changelog.d/9128.bugfix @@ -0,0 +1 @@ +Fix minor bugs in handling the `clientRedirectUrl` parameter for SSO login. diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 18cd2b62f0..0e98db22b3 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -1504,8 +1504,8 @@ class AuthHandler(BaseHandler): @staticmethod def add_query_param_to_url(url: str, param_name: str, param: Any): url_parts = list(urllib.parse.urlparse(url)) - query = dict(urllib.parse.parse_qsl(url_parts[4])) - query.update({param_name: param}) + query = urllib.parse.parse_qsl(url_parts[4], keep_blank_values=True) + query.append((param_name, param)) url_parts[4] = urllib.parse.urlencode(query) return urllib.parse.urlunparse(url_parts) diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py index 5e5fda7b2f..ba686d74b2 100644 --- a/synapse/handlers/oidc_handler.py +++ b/synapse/handlers/oidc_handler.py @@ -85,7 +85,7 @@ class OidcHandler: self._token_generator = OidcSessionTokenGenerator(hs) self._providers = { p.idp_id: OidcProvider(hs, self._token_generator, p) for p in provider_confs - } + } # type: Dict[str, OidcProvider] async def load_metadata(self) -> None: """Validate the config and load the metadata from the remote endpoint. diff --git a/synapse/rest/synapse/client/pick_idp.py b/synapse/rest/synapse/client/pick_idp.py index e5b720bbca..9550b82998 100644 --- a/synapse/rest/synapse/client/pick_idp.py +++ b/synapse/rest/synapse/client/pick_idp.py @@ -45,7 +45,9 @@ class PickIdpResource(DirectServeHtmlResource): self._server_name = hs.hostname async def _async_render_GET(self, request: SynapseRequest) -> None: - client_redirect_url = parse_string(request, "redirectUrl", required=True) + client_redirect_url = parse_string( + request, "redirectUrl", required=True, encoding="utf-8" + ) idp = parse_string(request, "idp", required=False) # if we need to pick an IdP, do so diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py index 73a009efd1..2d25490374 100644 --- a/tests/rest/client/v1/test_login.py +++ b/tests/rest/client/v1/test_login.py @@ -15,9 +15,8 @@ import time import urllib.parse -from html.parser import HTMLParser -from typing import Any, Dict, Iterable, List, Optional, Tuple, Union -from urllib.parse import parse_qs, urlencode, urlparse +from typing import Any, Dict, Union +from urllib.parse import urlencode from mock import Mock @@ -38,6 +37,7 @@ from tests import unittest from tests.handlers.test_oidc import HAS_OIDC from tests.handlers.test_saml import has_saml2 from tests.rest.client.v1.utils import TEST_OIDC_AUTH_ENDPOINT, TEST_OIDC_CONFIG +from tests.test_utils.html_parsers import TestHtmlParser from tests.unittest import HomeserverTestCase, override_config, skip_unless try: @@ -69,6 +69,12 @@ TEST_SAML_METADATA = """ LOGIN_URL = b"/_matrix/client/r0/login" TEST_URL = b"/_matrix/client/r0/account/whoami" +# a (valid) url with some annoying characters in. %3D is =, %26 is &, %2B is + +TEST_CLIENT_REDIRECT_URL = 'https://x?&q"+%3D%2B"="fö%26=o"' + +# the query params in TEST_CLIENT_REDIRECT_URL +EXPECTED_CLIENT_REDIRECT_URL_PARAMS = [("", ""), ('q" =+"', '"fö&=o"')] + class LoginRestServletTestCase(unittest.HomeserverTestCase): @@ -389,23 +395,44 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): }, } + # default OIDC provider config["oidc_config"] = TEST_OIDC_CONFIG + # additional OIDC providers + config["oidc_providers"] = [ + { + "idp_id": "idp1", + "idp_name": "IDP1", + "discover": False, + "issuer": "https://issuer1", + "client_id": "test-client-id", + "client_secret": "test-client-secret", + "scopes": ["profile"], + "authorization_endpoint": "https://issuer1/auth", + "token_endpoint": "https://issuer1/token", + "userinfo_endpoint": "https://issuer1/userinfo", + "user_mapping_provider": { + "config": {"localpart_template": "{{ user.sub }}"} + }, + } + ] return config def create_resource_dict(self) -> Dict[str, Resource]: + from synapse.rest.oidc import OIDCResource + d = super().create_resource_dict() d["/_synapse/client/pick_idp"] = PickIdpResource(self.hs) + d["/_synapse/oidc"] = OIDCResource(self.hs) return d def test_multi_sso_redirect(self): """/login/sso/redirect should redirect to an identity picker""" - client_redirect_url = "https://x?" - # first hit the redirect url, which should redirect to our idp picker channel = self.make_request( "GET", - "/_matrix/client/r0/login/sso/redirect?redirectUrl=" + client_redirect_url, + "/_matrix/client/r0/login/sso/redirect?redirectUrl=" + + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL), ) self.assertEqual(channel.code, 302, channel.result) uri = channel.headers.getRawHeaders("Location")[0] @@ -415,46 +442,22 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.code, 200, channel.result) # parse the form to check it has fields assumed elsewhere in this class - class FormPageParser(HTMLParser): - def __init__(self): - super().__init__() - - # the values of the hidden inputs: map from name to value - self.hiddens = {} # type: Dict[str, Optional[str]] - - # the values of the radio buttons - self.radios = [] # type: List[Optional[str]] - - def handle_starttag( - self, tag: str, attrs: Iterable[Tuple[str, Optional[str]]] - ) -> None: - attr_dict = dict(attrs) - if tag == "input": - if attr_dict["type"] == "radio" and attr_dict["name"] == "idp": - self.radios.append(attr_dict["value"]) - elif attr_dict["type"] == "hidden": - input_name = attr_dict["name"] - assert input_name - self.hiddens[input_name] = attr_dict["value"] - - def error(_, message): - self.fail(message) - - p = FormPageParser() + p = TestHtmlParser() p.feed(channel.result["body"].decode("utf-8")) p.close() - self.assertCountEqual(p.radios, ["cas", "oidc", "saml"]) + self.assertCountEqual(p.radios["idp"], ["cas", "oidc", "idp1", "saml"]) - self.assertEqual(p.hiddens["redirectUrl"], client_redirect_url) + self.assertEqual(p.hiddens["redirectUrl"], TEST_CLIENT_REDIRECT_URL) def test_multi_sso_redirect_to_cas(self): """If CAS is chosen, should redirect to the CAS server""" - client_redirect_url = "https://x?" channel = self.make_request( "GET", - "/_synapse/client/pick_idp?redirectUrl=" + client_redirect_url + "&idp=cas", + "/_synapse/client/pick_idp?redirectUrl=" + + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL) + + "&idp=cas", shorthand=False, ) self.assertEqual(channel.code, 302, channel.result) @@ -470,16 +473,14 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): service_uri = cas_uri_params["service"][0] _, service_uri_query = service_uri.split("?", 1) service_uri_params = urllib.parse.parse_qs(service_uri_query) - self.assertEqual(service_uri_params["redirectUrl"][0], client_redirect_url) + self.assertEqual(service_uri_params["redirectUrl"][0], TEST_CLIENT_REDIRECT_URL) def test_multi_sso_redirect_to_saml(self): """If SAML is chosen, should redirect to the SAML server""" - client_redirect_url = "https://x?" - channel = self.make_request( "GET", "/_synapse/client/pick_idp?redirectUrl=" - + client_redirect_url + + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL) + "&idp=saml", ) self.assertEqual(channel.code, 302, channel.result) @@ -492,16 +493,16 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): # the RelayState is used to carry the client redirect url saml_uri_params = urllib.parse.parse_qs(saml_uri_query) relay_state_param = saml_uri_params["RelayState"][0] - self.assertEqual(relay_state_param, client_redirect_url) + self.assertEqual(relay_state_param, TEST_CLIENT_REDIRECT_URL) - def test_multi_sso_redirect_to_oidc(self): + def test_login_via_oidc(self): """If OIDC is chosen, should redirect to the OIDC auth endpoint""" - client_redirect_url = "https://x?" + # pick the default OIDC provider channel = self.make_request( "GET", "/_synapse/client/pick_idp?redirectUrl=" - + client_redirect_url + + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL) + "&idp=oidc", ) self.assertEqual(channel.code, 302, channel.result) @@ -521,9 +522,41 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): macaroon = pymacaroons.Macaroon.deserialize(oidc_session_cookie) self.assertEqual( self._get_value_from_macaroon(macaroon, "client_redirect_url"), - client_redirect_url, + TEST_CLIENT_REDIRECT_URL, ) + channel = self.helper.complete_oidc_auth(oidc_uri, cookies, {"sub": "user1"}) + + # that should serve a confirmation page + self.assertEqual(channel.code, 200, channel.result) + self.assertTrue( + channel.headers.getRawHeaders("Content-Type")[-1].startswith("text/html") + ) + p = TestHtmlParser() + p.feed(channel.text_body) + p.close() + + # ... which should contain our redirect link + self.assertEqual(len(p.links), 1) + path, query = p.links[0].split("?", 1) + self.assertEqual(path, "https://x") + + # it will have url-encoded the params properly, so we'll have to parse them + params = urllib.parse.parse_qsl( + query, keep_blank_values=True, strict_parsing=True, errors="strict" + ) + self.assertEqual(params[0:2], EXPECTED_CLIENT_REDIRECT_URL_PARAMS) + self.assertEqual(params[2][0], "loginToken") + + # finally, submit the matrix login token to the login API, which gives us our + # matrix access token, mxid, and device id. + login_token = params[2][1] + chan = self.make_request( + "POST", "/login", content={"type": "m.login.token", "token": login_token}, + ) + self.assertEqual(chan.code, 200, chan.result) + self.assertEqual(chan.json_body["user_id"], "@user1:test") + def test_multi_sso_redirect_to_unknown(self): """An unknown IdP should cause a 400""" channel = self.make_request( @@ -1082,7 +1115,7 @@ class UsernamePickerTestCase(HomeserverTestCase): # whitelist this client URI so we redirect straight to it rather than # serving a confirmation page - config["sso"] = {"client_whitelist": ["https://whitelisted.client"]} + config["sso"] = {"client_whitelist": ["https://x"]} return config def create_resource_dict(self) -> Dict[str, Resource]: @@ -1095,11 +1128,10 @@ class UsernamePickerTestCase(HomeserverTestCase): def test_username_picker(self): """Test the happy path of a username picker flow.""" - client_redirect_url = "https://whitelisted.client" # do the start of the login flow channel = self.helper.auth_via_oidc( - {"sub": "tester", "displayname": "Jonny"}, client_redirect_url + {"sub": "tester", "displayname": "Jonny"}, TEST_CLIENT_REDIRECT_URL ) # that should redirect to the username picker @@ -1122,7 +1154,7 @@ class UsernamePickerTestCase(HomeserverTestCase): session = username_mapping_sessions[session_id] self.assertEqual(session.remote_user_id, "tester") self.assertEqual(session.display_name, "Jonny") - self.assertEqual(session.client_redirect_url, client_redirect_url) + self.assertEqual(session.client_redirect_url, TEST_CLIENT_REDIRECT_URL) # the expiry time should be about 15 minutes away expected_expiry = self.clock.time_msec() + (15 * 60 * 1000) @@ -1146,15 +1178,19 @@ class UsernamePickerTestCase(HomeserverTestCase): ) self.assertEqual(chan.code, 302, chan.result) location_headers = chan.headers.getRawHeaders("Location") - # ensure that the returned location starts with the requested redirect URL - self.assertEqual( - location_headers[0][: len(client_redirect_url)], client_redirect_url + # ensure that the returned location matches the requested redirect URL + path, query = location_headers[0].split("?", 1) + self.assertEqual(path, "https://x") + + # it will have url-encoded the params properly, so we'll have to parse them + params = urllib.parse.parse_qsl( + query, keep_blank_values=True, strict_parsing=True, errors="strict" ) + self.assertEqual(params[0:2], EXPECTED_CLIENT_REDIRECT_URL_PARAMS) + self.assertEqual(params[2][0], "loginToken") # fish the login token out of the returned redirect uri - parts = urlparse(location_headers[0]) - query = parse_qs(parts.query) - login_token = query["loginToken"][0] + login_token = params[2][1] # finally, submit the matrix login token to the login API, which gives us our # matrix access token, mxid, and device id. diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py index c6647dbe08..b1333df82d 100644 --- a/tests/rest/client/v1/utils.py +++ b/tests/rest/client/v1/utils.py @@ -20,8 +20,7 @@ import json import re import time import urllib.parse -from html.parser import HTMLParser -from typing import Any, Dict, Iterable, List, MutableMapping, Optional, Tuple +from typing import Any, Dict, Mapping, MutableMapping, Optional from mock import patch @@ -35,6 +34,7 @@ from synapse.types import JsonDict from tests.server import FakeChannel, FakeSite, make_request from tests.test_utils import FakeResponse +from tests.test_utils.html_parsers import TestHtmlParser @attr.s @@ -440,10 +440,36 @@ class RestHelper: # param that synapse passes to the IdP via query params, as well as the cookie # that synapse passes to the client. - oauth_uri_path, oauth_uri_qs = oauth_uri.split("?", 1) + oauth_uri_path, _ = oauth_uri.split("?", 1) assert oauth_uri_path == TEST_OIDC_AUTH_ENDPOINT, ( "unexpected SSO URI " + oauth_uri_path ) + return self.complete_oidc_auth(oauth_uri, cookies, user_info_dict) + + def complete_oidc_auth( + self, oauth_uri: str, cookies: Mapping[str, str], user_info_dict: JsonDict, + ) -> FakeChannel: + """Mock out an OIDC authentication flow + + Assumes that an OIDC auth has been initiated by one of initiate_sso_login or + initiate_sso_ui_auth; completes the OIDC bits of the flow by making a request to + Synapse's OIDC callback endpoint, intercepting the HTTP requests that will get + sent back to the OIDC provider. + + Requires the OIDC callback resource to be mounted at the normal place. + + Args: + oauth_uri: the OIDC URI returned by synapse's redirect endpoint (ie, + from initiate_sso_login or initiate_sso_ui_auth). + cookies: the cookies set by synapse's redirect endpoint, which will be + sent back to the callback endpoint. + user_info_dict: the remote userinfo that the OIDC provider should present. + Typically this should be '{"sub": ""}'. + + Returns: + A FakeChannel containing the result of calling the OIDC callback endpoint. + """ + _, oauth_uri_qs = oauth_uri.split("?", 1) params = urllib.parse.parse_qs(oauth_uri_qs) callback_uri = "%s?%s" % ( urllib.parse.urlparse(params["redirect_uri"][0]).path, @@ -456,9 +482,9 @@ class RestHelper: expected_requests = [ # first we get a hit to the token endpoint, which we tell to return # a dummy OIDC access token - ("https://issuer.test/token", {"access_token": "TEST"}), + (TEST_OIDC_TOKEN_ENDPOINT, {"access_token": "TEST"}), # and then one to the user_info endpoint, which returns our remote user id. - ("https://issuer.test/userinfo", user_info_dict), + (TEST_OIDC_USERINFO_ENDPOINT, user_info_dict), ] async def mock_req(method: str, uri: str, data=None, headers=None): @@ -542,25 +568,7 @@ class RestHelper: channel.extract_cookies(cookies) # parse the confirmation page to fish out the link. - class ConfirmationPageParser(HTMLParser): - def __init__(self): - super().__init__() - - self.links = [] # type: List[str] - - def handle_starttag( - self, tag: str, attrs: Iterable[Tuple[str, Optional[str]]] - ) -> None: - attr_dict = dict(attrs) - if tag == "a": - href = attr_dict["href"] - if href: - self.links.append(href) - - def error(_, message): - raise AssertionError(message) - - p = ConfirmationPageParser() + p = TestHtmlParser() p.feed(channel.text_body) p.close() assert len(p.links) == 1, "not exactly one link in confirmation page" @@ -570,6 +578,8 @@ class RestHelper: # an 'oidc_config' suitable for login_via_oidc. TEST_OIDC_AUTH_ENDPOINT = "https://issuer.test/auth" +TEST_OIDC_TOKEN_ENDPOINT = "https://issuer.test/token" +TEST_OIDC_USERINFO_ENDPOINT = "https://issuer.test/userinfo" TEST_OIDC_CONFIG = { "enabled": True, "discover": False, @@ -578,7 +588,7 @@ TEST_OIDC_CONFIG = { "client_secret": "test-client-secret", "scopes": ["profile"], "authorization_endpoint": TEST_OIDC_AUTH_ENDPOINT, - "token_endpoint": "https://issuer.test/token", - "userinfo_endpoint": "https://issuer.test/userinfo", + "token_endpoint": TEST_OIDC_TOKEN_ENDPOINT, + "userinfo_endpoint": TEST_OIDC_USERINFO_ENDPOINT, "user_mapping_provider": {"config": {"localpart_template": "{{ user.sub }}"}}, } diff --git a/tests/server.py b/tests/server.py index 5a1b66270f..5a85d5fe7f 100644 --- a/tests/server.py +++ b/tests/server.py @@ -74,7 +74,7 @@ class FakeChannel: return int(self.result["code"]) @property - def headers(self): + def headers(self) -> Headers: if not self.result: raise Exception("No result yet.") h = Headers() diff --git a/tests/test_utils/html_parsers.py b/tests/test_utils/html_parsers.py new file mode 100644 index 0000000000..ad563eb3f0 --- /dev/null +++ b/tests/test_utils/html_parsers.py @@ -0,0 +1,53 @@ +# -*- coding: utf-8 -*- +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from html.parser import HTMLParser +from typing import Dict, Iterable, List, Optional, Tuple + + +class TestHtmlParser(HTMLParser): + """A generic HTML page parser which extracts useful things from the HTML""" + + def __init__(self): + super().__init__() + + # a list of links found in the doc + self.links = [] # type: List[str] + + # the values of any hidden s: map from name to value + self.hiddens = {} # type: Dict[str, Optional[str]] + + # the values of any radio buttons: map from name to list of values + self.radios = {} # type: Dict[str, List[Optional[str]]] + + def handle_starttag( + self, tag: str, attrs: Iterable[Tuple[str, Optional[str]]] + ) -> None: + attr_dict = dict(attrs) + if tag == "a": + href = attr_dict["href"] + if href: + self.links.append(href) + elif tag == "input": + input_name = attr_dict.get("name") + if attr_dict["type"] == "radio": + assert input_name + self.radios.setdefault(input_name, []).append(attr_dict["value"]) + elif attr_dict["type"] == "hidden": + assert input_name + self.hiddens[input_name] = attr_dict["value"] + + def error(_, message): + raise AssertionError(message) -- cgit 1.5.1