diff options
-rw-r--r-- | changelog.d/9617.feature | 1 | ||||
-rw-r--r-- | docs/openid.md | 8 | ||||
-rw-r--r-- | docs/sample_config.yaml | 2 | ||||
-rw-r--r-- | synapse/config/oidc_config.py | 13 | ||||
-rw-r--r-- | synapse/handlers/cas_handler.py | 1 | ||||
-rw-r--r-- | synapse/handlers/oidc_handler.py | 3 | ||||
-rw-r--r-- | synapse/handlers/saml_handler.py | 1 | ||||
-rw-r--r-- | synapse/handlers/sso.py | 5 | ||||
-rw-r--r-- | synapse/rest/client/v1/login.py | 39 | ||||
-rw-r--r-- | tests/rest/client/v1/test_login.py | 43 |
10 files changed, 88 insertions, 28 deletions
diff --git a/changelog.d/9617.feature b/changelog.d/9617.feature new file mode 100644 index 0000000000..b462a32b92 --- /dev/null +++ b/changelog.d/9617.feature @@ -0,0 +1 @@ +Finalise support for allowing clients to pick an SSO Identity Provider ([MSC2858](https://github.com/matrix-org/matrix-doc/pull/2858)). diff --git a/docs/openid.md b/docs/openid.md index 01205d1220..cfaafc5015 100644 --- a/docs/openid.md +++ b/docs/openid.md @@ -226,7 +226,7 @@ Synapse config: oidc_providers: - idp_id: github idp_name: Github - idp_brand: "org.matrix.github" # optional: styling hint for clients + idp_brand: "github" # optional: styling hint for clients discover: false issuer: "https://github.com/" client_id: "your-client-id" # TO BE FILLED @@ -252,7 +252,7 @@ oidc_providers: oidc_providers: - idp_id: google idp_name: Google - idp_brand: "org.matrix.google" # optional: styling hint for clients + idp_brand: "google" # optional: styling hint for clients issuer: "https://accounts.google.com/" client_id: "your-client-id" # TO BE FILLED client_secret: "your-client-secret" # TO BE FILLED @@ -299,7 +299,7 @@ Synapse config: oidc_providers: - idp_id: gitlab idp_name: Gitlab - idp_brand: "org.matrix.gitlab" # optional: styling hint for clients + idp_brand: "gitlab" # optional: styling hint for clients issuer: "https://gitlab.com/" client_id: "your-client-id" # TO BE FILLED client_secret: "your-client-secret" # TO BE FILLED @@ -334,7 +334,7 @@ Synapse config: ```yaml - idp_id: facebook idp_name: Facebook - idp_brand: "org.matrix.facebook" # optional: styling hint for clients + idp_brand: "facebook" # optional: styling hint for clients discover: false issuer: "https://facebook.com" client_id: "your-client-id" # TO BE FILLED diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 41ab35595b..7de000f4a4 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -1919,7 +1919,7 @@ oidc_providers: # #- idp_id: github # idp_name: Github - # idp_brand: org.matrix.github + # idp_brand: github # discover: false # issuer: "https://github.com/" # client_id: "your-client-id" # TO BE FILLED diff --git a/synapse/config/oidc_config.py b/synapse/config/oidc_config.py index 7f5e449eb2..2bfb537c15 100644 --- a/synapse/config/oidc_config.py +++ b/synapse/config/oidc_config.py @@ -237,7 +237,7 @@ class OIDCConfig(Config): # #- idp_id: github # idp_name: Github - # idp_brand: org.matrix.github + # idp_brand: github # discover: false # issuer: "https://github.com/" # client_id: "your-client-id" # TO BE FILLED @@ -272,7 +272,12 @@ OIDC_PROVIDER_CONFIG_SCHEMA = { "idp_icon": {"type": "string"}, "idp_brand": { "type": "string", - # MSC2758-style namespaced identifier + "minLength": 1, + "maxLength": 255, + "pattern": "^[a-z][a-z0-9_.-]*$", + }, + "idp_unstable_brand": { + "type": "string", "minLength": 1, "maxLength": 255, "pattern": "^[a-z][a-z0-9_.-]*$", @@ -466,6 +471,7 @@ def _parse_oidc_config_dict( idp_name=oidc_config.get("idp_name", "OIDC"), idp_icon=idp_icon, idp_brand=oidc_config.get("idp_brand"), + unstable_idp_brand=oidc_config.get("unstable_idp_brand"), discover=oidc_config.get("discover", True), issuer=oidc_config["issuer"], client_id=oidc_config["client_id"], @@ -512,6 +518,9 @@ class OidcProviderConfig: # Optional brand identifier for this IdP. idp_brand = attr.ib(type=Optional[str]) + # Optional brand identifier for the unstable API (see MSC2858). + unstable_idp_brand = attr.ib(type=Optional[str]) + # whether the OIDC discovery mechanism is used to discover endpoints discover = attr.ib(type=bool) diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py index 04972f9cf0..cb67589f7d 100644 --- a/synapse/handlers/cas_handler.py +++ b/synapse/handlers/cas_handler.py @@ -83,6 +83,7 @@ class CasHandler: # the SsoIdentityProvider protocol type. self.idp_icon = None self.idp_brand = None + self.unstable_idp_brand = None self._sso_handler = hs.get_sso_handler() diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py index f5d1821127..01c91f9d1c 100644 --- a/synapse/handlers/oidc_handler.py +++ b/synapse/handlers/oidc_handler.py @@ -330,6 +330,9 @@ class OidcProvider: # optional brand identifier for this auth provider self.idp_brand = provider.idp_brand + # Optional brand identifier for the unstable API (see MSC2858). + self.unstable_idp_brand = provider.unstable_idp_brand + self._sso_handler = hs.get_sso_handler() self._sso_handler.register_identity_provider(self) diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py index a9645b77d8..ec2ba11c75 100644 --- a/synapse/handlers/saml_handler.py +++ b/synapse/handlers/saml_handler.py @@ -81,6 +81,7 @@ class SamlHandler(BaseHandler): # the SsoIdentityProvider protocol type. self.idp_icon = None self.idp_brand = None + self.unstable_idp_brand = None # a map from saml session id to Saml2SessionData object self._outstanding_requests_dict = {} # type: Dict[str, Saml2SessionData] diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index 6ef459acff..415b1c2d17 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -98,6 +98,11 @@ class SsoIdentityProvider(Protocol): """Optional branding identifier""" return None + @property + def unstable_idp_brand(self) -> Optional[str]: + """Optional brand identifier for the unstable API (see MSC2858).""" + return None + @abc.abstractmethod async def handle_redirect_request( self, diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 34bc1bd49b..e4c352f572 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -14,10 +14,12 @@ # limitations under the License. import logging +import re from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Optional from synapse.api.errors import Codes, LoginError, SynapseError from synapse.api.ratelimiting import Ratelimiter +from synapse.api.urls import CLIENT_API_PREFIX from synapse.appservice import ApplicationService from synapse.handlers.sso import SsoIdentityProvider from synapse.http import get_request_uri @@ -94,11 +96,21 @@ class LoginRestServlet(RestServlet): flows.append({"type": LoginRestServlet.CAS_TYPE}) if self.cas_enabled or self.saml2_enabled or self.oidc_enabled: - sso_flow = {"type": LoginRestServlet.SSO_TYPE} # type: JsonDict + sso_flow = { + "type": LoginRestServlet.SSO_TYPE, + "identity_providers": [ + _get_auth_flow_dict_for_idp( + idp, + ) + for idp in self._sso_handler.get_identity_providers().values() + ], + } # type: JsonDict if self._msc2858_enabled: + # backwards-compatibility support for clients which don't + # support the stable API yet sso_flow["org.matrix.msc2858.identity_providers"] = [ - _get_auth_flow_dict_for_idp(idp) + _get_auth_flow_dict_for_idp(idp, use_unstable_brands=True) for idp in self._sso_handler.get_identity_providers().values() ] @@ -331,22 +343,38 @@ class LoginRestServlet(RestServlet): return result -def _get_auth_flow_dict_for_idp(idp: SsoIdentityProvider) -> JsonDict: +def _get_auth_flow_dict_for_idp( + idp: SsoIdentityProvider, use_unstable_brands: bool = False +) -> JsonDict: """Return an entry for the login flow dict Returns an entry suitable for inclusion in "identity_providers" in the response to GET /_matrix/client/r0/login + + Args: + idp: the identity provider to describe + use_unstable_brands: whether we should use brand identifiers suitable + for the unstable API """ e = {"id": idp.idp_id, "name": idp.idp_name} # type: JsonDict if idp.idp_icon: e["icon"] = idp.idp_icon if idp.idp_brand: e["brand"] = idp.idp_brand + # use the stable brand identifier if the unstable identifier isn't defined. + if use_unstable_brands and idp.unstable_idp_brand: + e["brand"] = idp.unstable_idp_brand return e class SsoRedirectServlet(RestServlet): - PATTERNS = client_patterns("/login/(cas|sso)/redirect$", v1=True) + PATTERNS = list(client_patterns("/login/(cas|sso)/redirect$", v1=True)) + [ + re.compile( + "^" + + CLIENT_API_PREFIX + + "/r0/login/sso/redirect/(?P<idp_id>[A-Za-z0-9_.~-]+)$" + ) + ] def __init__(self, hs: "HomeServer"): # make sure that the relevant handlers are instantiated, so that they @@ -364,7 +392,8 @@ class SsoRedirectServlet(RestServlet): def register(self, http_server: HttpServer) -> None: super().register(http_server) if self._msc2858_enabled: - # expose additional endpoint for MSC2858 support + # expose additional endpoint for MSC2858 support: backwards-compat support + # for clients which don't yet support the stable endpoints. http_server.register_paths( "GET", client_patterns( diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py index 20af3285bd..988821b16f 100644 --- a/tests/rest/client/v1/test_login.py +++ b/tests/rest/client/v1/test_login.py @@ -437,14 +437,16 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): channel = self.make_request("GET", "/_matrix/client/r0/login") self.assertEqual(channel.code, 200, channel.result) - expected_flows = [ - {"type": "m.login.cas"}, - {"type": "m.login.sso"}, - {"type": "m.login.token"}, - {"type": "m.login.password"}, - ] + ADDITIONAL_LOGIN_FLOWS + expected_flow_types = [ + "m.login.cas", + "m.login.sso", + "m.login.token", + "m.login.password", + ] + [f["type"] for f in ADDITIONAL_LOGIN_FLOWS] - self.assertCountEqual(channel.json_body["flows"], expected_flows) + self.assertCountEqual( + [f["type"] for f in channel.json_body["flows"]], expected_flow_types + ) @override_config({"experimental_features": {"msc2858_enabled": True}}) def test_get_msc2858_login_flows(self): @@ -636,22 +638,25 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): ) self.assertEqual(channel.code, 400, channel.result) - def test_client_idp_redirect_msc2858_disabled(self): - """If the client tries to pick an IdP but MSC2858 is disabled, return a 400""" - channel = self._make_sso_redirect_request(True, "oidc") - self.assertEqual(channel.code, 400, channel.result) - self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED") - - @override_config({"experimental_features": {"msc2858_enabled": True}}) def test_client_idp_redirect_to_unknown(self): """If the client tries to pick an unknown IdP, return a 404""" - channel = self._make_sso_redirect_request(True, "xxx") + channel = self._make_sso_redirect_request(False, "xxx") self.assertEqual(channel.code, 404, channel.result) self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND") - @override_config({"experimental_features": {"msc2858_enabled": True}}) def test_client_idp_redirect_to_oidc(self): """If the client pick a known IdP, redirect to it""" + channel = self._make_sso_redirect_request(False, "oidc") + self.assertEqual(channel.code, 302, channel.result) + oidc_uri = channel.headers.getRawHeaders("Location")[0] + oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1) + + # it should redirect us to the auth page of the OIDC server + self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT) + + @override_config({"experimental_features": {"msc2858_enabled": True}}) + def test_client_msc2858_redirect_to_oidc(self): + """Test the unstable API""" channel = self._make_sso_redirect_request(True, "oidc") self.assertEqual(channel.code, 302, channel.result) oidc_uri = channel.headers.getRawHeaders("Location")[0] @@ -660,6 +665,12 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): # it should redirect us to the auth page of the OIDC server self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT) + def test_client_idp_redirect_msc2858_disabled(self): + """If the client tries to use the MSC2858 endpoint but MSC2858 is disabled, return a 400""" + channel = self._make_sso_redirect_request(True, "oidc") + self.assertEqual(channel.code, 400, channel.result) + self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED") + def _make_sso_redirect_request( self, unstable_endpoint: bool = False, idp_prov: Optional[str] = None ): |