summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/9617.feature1
-rw-r--r--docs/openid.md8
-rw-r--r--docs/sample_config.yaml2
-rw-r--r--synapse/config/oidc_config.py13
-rw-r--r--synapse/handlers/cas_handler.py1
-rw-r--r--synapse/handlers/oidc_handler.py3
-rw-r--r--synapse/handlers/saml_handler.py1
-rw-r--r--synapse/handlers/sso.py5
-rw-r--r--synapse/rest/client/v1/login.py39
-rw-r--r--tests/rest/client/v1/test_login.py43
10 files changed, 88 insertions, 28 deletions
diff --git a/changelog.d/9617.feature b/changelog.d/9617.feature
new file mode 100644
index 0000000000..b462a32b92
--- /dev/null
+++ b/changelog.d/9617.feature
@@ -0,0 +1 @@
+Finalise support for allowing clients to pick an SSO Identity Provider ([MSC2858](https://github.com/matrix-org/matrix-doc/pull/2858)).
diff --git a/docs/openid.md b/docs/openid.md
index 01205d1220..cfaafc5015 100644
--- a/docs/openid.md
+++ b/docs/openid.md
@@ -226,7 +226,7 @@ Synapse config:
 oidc_providers:
   - idp_id: github
     idp_name: Github
-    idp_brand: "org.matrix.github"  # optional: styling hint for clients
+    idp_brand: "github"  # optional: styling hint for clients
     discover: false
     issuer: "https://github.com/"
     client_id: "your-client-id" # TO BE FILLED
@@ -252,7 +252,7 @@ oidc_providers:
    oidc_providers:
      - idp_id: google
        idp_name: Google
-       idp_brand: "org.matrix.google"  # optional: styling hint for clients
+       idp_brand: "google"  # optional: styling hint for clients
        issuer: "https://accounts.google.com/"
        client_id: "your-client-id" # TO BE FILLED
        client_secret: "your-client-secret" # TO BE FILLED
@@ -299,7 +299,7 @@ Synapse config:
 oidc_providers:
   - idp_id: gitlab
     idp_name: Gitlab
-    idp_brand: "org.matrix.gitlab"  # optional: styling hint for clients
+    idp_brand: "gitlab"  # optional: styling hint for clients
     issuer: "https://gitlab.com/"
     client_id: "your-client-id" # TO BE FILLED
     client_secret: "your-client-secret" # TO BE FILLED
@@ -334,7 +334,7 @@ Synapse config:
 ```yaml
   - idp_id: facebook
     idp_name: Facebook
-    idp_brand: "org.matrix.facebook"  # optional: styling hint for clients
+    idp_brand: "facebook"  # optional: styling hint for clients
     discover: false
     issuer: "https://facebook.com"
     client_id: "your-client-id" # TO BE FILLED
diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml
index 41ab35595b..7de000f4a4 100644
--- a/docs/sample_config.yaml
+++ b/docs/sample_config.yaml
@@ -1919,7 +1919,7 @@ oidc_providers:
   #
   #- idp_id: github
   #  idp_name: Github
-  #  idp_brand: org.matrix.github
+  #  idp_brand: github
   #  discover: false
   #  issuer: "https://github.com/"
   #  client_id: "your-client-id" # TO BE FILLED
diff --git a/synapse/config/oidc_config.py b/synapse/config/oidc_config.py
index 7f5e449eb2..2bfb537c15 100644
--- a/synapse/config/oidc_config.py
+++ b/synapse/config/oidc_config.py
@@ -237,7 +237,7 @@ class OIDCConfig(Config):
           #
           #- idp_id: github
           #  idp_name: Github
-          #  idp_brand: org.matrix.github
+          #  idp_brand: github
           #  discover: false
           #  issuer: "https://github.com/"
           #  client_id: "your-client-id" # TO BE FILLED
@@ -272,7 +272,12 @@ OIDC_PROVIDER_CONFIG_SCHEMA = {
         "idp_icon": {"type": "string"},
         "idp_brand": {
             "type": "string",
-            # MSC2758-style namespaced identifier
+            "minLength": 1,
+            "maxLength": 255,
+            "pattern": "^[a-z][a-z0-9_.-]*$",
+        },
+        "idp_unstable_brand": {
+            "type": "string",
             "minLength": 1,
             "maxLength": 255,
             "pattern": "^[a-z][a-z0-9_.-]*$",
@@ -466,6 +471,7 @@ def _parse_oidc_config_dict(
         idp_name=oidc_config.get("idp_name", "OIDC"),
         idp_icon=idp_icon,
         idp_brand=oidc_config.get("idp_brand"),
+        unstable_idp_brand=oidc_config.get("unstable_idp_brand"),
         discover=oidc_config.get("discover", True),
         issuer=oidc_config["issuer"],
         client_id=oidc_config["client_id"],
@@ -512,6 +518,9 @@ class OidcProviderConfig:
     # Optional brand identifier for this IdP.
     idp_brand = attr.ib(type=Optional[str])
 
+    # Optional brand identifier for the unstable API (see MSC2858).
+    unstable_idp_brand = attr.ib(type=Optional[str])
+
     # whether the OIDC discovery mechanism is used to discover endpoints
     discover = attr.ib(type=bool)
 
diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py
index 04972f9cf0..cb67589f7d 100644
--- a/synapse/handlers/cas_handler.py
+++ b/synapse/handlers/cas_handler.py
@@ -83,6 +83,7 @@ class CasHandler:
         # the SsoIdentityProvider protocol type.
         self.idp_icon = None
         self.idp_brand = None
+        self.unstable_idp_brand = None
 
         self._sso_handler = hs.get_sso_handler()
 
diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py
index f5d1821127..01c91f9d1c 100644
--- a/synapse/handlers/oidc_handler.py
+++ b/synapse/handlers/oidc_handler.py
@@ -330,6 +330,9 @@ class OidcProvider:
         # optional brand identifier for this auth provider
         self.idp_brand = provider.idp_brand
 
+        # Optional brand identifier for the unstable API (see MSC2858).
+        self.unstable_idp_brand = provider.unstable_idp_brand
+
         self._sso_handler = hs.get_sso_handler()
 
         self._sso_handler.register_identity_provider(self)
diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py
index a9645b77d8..ec2ba11c75 100644
--- a/synapse/handlers/saml_handler.py
+++ b/synapse/handlers/saml_handler.py
@@ -81,6 +81,7 @@ class SamlHandler(BaseHandler):
         # the SsoIdentityProvider protocol type.
         self.idp_icon = None
         self.idp_brand = None
+        self.unstable_idp_brand = None
 
         # a map from saml session id to Saml2SessionData object
         self._outstanding_requests_dict = {}  # type: Dict[str, Saml2SessionData]
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index 6ef459acff..415b1c2d17 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -98,6 +98,11 @@ class SsoIdentityProvider(Protocol):
         """Optional branding identifier"""
         return None
 
+    @property
+    def unstable_idp_brand(self) -> Optional[str]:
+        """Optional brand identifier for the unstable API (see MSC2858)."""
+        return None
+
     @abc.abstractmethod
     async def handle_redirect_request(
         self,
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 34bc1bd49b..e4c352f572 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -14,10 +14,12 @@
 # limitations under the License.
 
 import logging
+import re
 from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Optional
 
 from synapse.api.errors import Codes, LoginError, SynapseError
 from synapse.api.ratelimiting import Ratelimiter
+from synapse.api.urls import CLIENT_API_PREFIX
 from synapse.appservice import ApplicationService
 from synapse.handlers.sso import SsoIdentityProvider
 from synapse.http import get_request_uri
@@ -94,11 +96,21 @@ class LoginRestServlet(RestServlet):
             flows.append({"type": LoginRestServlet.CAS_TYPE})
 
         if self.cas_enabled or self.saml2_enabled or self.oidc_enabled:
-            sso_flow = {"type": LoginRestServlet.SSO_TYPE}  # type: JsonDict
+            sso_flow = {
+                "type": LoginRestServlet.SSO_TYPE,
+                "identity_providers": [
+                    _get_auth_flow_dict_for_idp(
+                        idp,
+                    )
+                    for idp in self._sso_handler.get_identity_providers().values()
+                ],
+            }  # type: JsonDict
 
             if self._msc2858_enabled:
+                # backwards-compatibility support for clients which don't
+                # support the stable API yet
                 sso_flow["org.matrix.msc2858.identity_providers"] = [
-                    _get_auth_flow_dict_for_idp(idp)
+                    _get_auth_flow_dict_for_idp(idp, use_unstable_brands=True)
                     for idp in self._sso_handler.get_identity_providers().values()
                 ]
 
@@ -331,22 +343,38 @@ class LoginRestServlet(RestServlet):
         return result
 
 
-def _get_auth_flow_dict_for_idp(idp: SsoIdentityProvider) -> JsonDict:
+def _get_auth_flow_dict_for_idp(
+    idp: SsoIdentityProvider, use_unstable_brands: bool = False
+) -> JsonDict:
     """Return an entry for the login flow dict
 
     Returns an entry suitable for inclusion in "identity_providers" in the
     response to GET /_matrix/client/r0/login
+
+    Args:
+        idp: the identity provider to describe
+        use_unstable_brands: whether we should use brand identifiers suitable
+           for the unstable API
     """
     e = {"id": idp.idp_id, "name": idp.idp_name}  # type: JsonDict
     if idp.idp_icon:
         e["icon"] = idp.idp_icon
     if idp.idp_brand:
         e["brand"] = idp.idp_brand
+    # use the stable brand identifier if the unstable identifier isn't defined.
+    if use_unstable_brands and idp.unstable_idp_brand:
+        e["brand"] = idp.unstable_idp_brand
     return e
 
 
 class SsoRedirectServlet(RestServlet):
-    PATTERNS = client_patterns("/login/(cas|sso)/redirect$", v1=True)
+    PATTERNS = list(client_patterns("/login/(cas|sso)/redirect$", v1=True)) + [
+        re.compile(
+            "^"
+            + CLIENT_API_PREFIX
+            + "/r0/login/sso/redirect/(?P<idp_id>[A-Za-z0-9_.~-]+)$"
+        )
+    ]
 
     def __init__(self, hs: "HomeServer"):
         # make sure that the relevant handlers are instantiated, so that they
@@ -364,7 +392,8 @@ class SsoRedirectServlet(RestServlet):
     def register(self, http_server: HttpServer) -> None:
         super().register(http_server)
         if self._msc2858_enabled:
-            # expose additional endpoint for MSC2858 support
+            # expose additional endpoint for MSC2858 support: backwards-compat support
+            # for clients which don't yet support the stable endpoints.
             http_server.register_paths(
                 "GET",
                 client_patterns(
diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py
index 20af3285bd..988821b16f 100644
--- a/tests/rest/client/v1/test_login.py
+++ b/tests/rest/client/v1/test_login.py
@@ -437,14 +437,16 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
         channel = self.make_request("GET", "/_matrix/client/r0/login")
         self.assertEqual(channel.code, 200, channel.result)
 
-        expected_flows = [
-            {"type": "m.login.cas"},
-            {"type": "m.login.sso"},
-            {"type": "m.login.token"},
-            {"type": "m.login.password"},
-        ] + ADDITIONAL_LOGIN_FLOWS
+        expected_flow_types = [
+            "m.login.cas",
+            "m.login.sso",
+            "m.login.token",
+            "m.login.password",
+        ] + [f["type"] for f in ADDITIONAL_LOGIN_FLOWS]
 
-        self.assertCountEqual(channel.json_body["flows"], expected_flows)
+        self.assertCountEqual(
+            [f["type"] for f in channel.json_body["flows"]], expected_flow_types
+        )
 
     @override_config({"experimental_features": {"msc2858_enabled": True}})
     def test_get_msc2858_login_flows(self):
@@ -636,22 +638,25 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
         )
         self.assertEqual(channel.code, 400, channel.result)
 
-    def test_client_idp_redirect_msc2858_disabled(self):
-        """If the client tries to pick an IdP but MSC2858 is disabled, return a 400"""
-        channel = self._make_sso_redirect_request(True, "oidc")
-        self.assertEqual(channel.code, 400, channel.result)
-        self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED")
-
-    @override_config({"experimental_features": {"msc2858_enabled": True}})
     def test_client_idp_redirect_to_unknown(self):
         """If the client tries to pick an unknown IdP, return a 404"""
-        channel = self._make_sso_redirect_request(True, "xxx")
+        channel = self._make_sso_redirect_request(False, "xxx")
         self.assertEqual(channel.code, 404, channel.result)
         self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND")
 
-    @override_config({"experimental_features": {"msc2858_enabled": True}})
     def test_client_idp_redirect_to_oidc(self):
         """If the client pick a known IdP, redirect to it"""
+        channel = self._make_sso_redirect_request(False, "oidc")
+        self.assertEqual(channel.code, 302, channel.result)
+        oidc_uri = channel.headers.getRawHeaders("Location")[0]
+        oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1)
+
+        # it should redirect us to the auth page of the OIDC server
+        self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT)
+
+    @override_config({"experimental_features": {"msc2858_enabled": True}})
+    def test_client_msc2858_redirect_to_oidc(self):
+        """Test the unstable API"""
         channel = self._make_sso_redirect_request(True, "oidc")
         self.assertEqual(channel.code, 302, channel.result)
         oidc_uri = channel.headers.getRawHeaders("Location")[0]
@@ -660,6 +665,12 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
         # it should redirect us to the auth page of the OIDC server
         self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT)
 
+    def test_client_idp_redirect_msc2858_disabled(self):
+        """If the client tries to use the MSC2858 endpoint but MSC2858 is disabled, return a 400"""
+        channel = self._make_sso_redirect_request(True, "oidc")
+        self.assertEqual(channel.code, 400, channel.result)
+        self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED")
+
     def _make_sso_redirect_request(
         self, unstable_endpoint: bool = False, idp_prov: Optional[str] = None
     ):