diff --git a/synapse/config/oidc_config.py b/synapse/config/oidc_config.py
index 7f5e449eb2..2bfb537c15 100644
--- a/synapse/config/oidc_config.py
+++ b/synapse/config/oidc_config.py
@@ -237,7 +237,7 @@ class OIDCConfig(Config):
#
#- idp_id: github
# idp_name: Github
- # idp_brand: org.matrix.github
+ # idp_brand: github
# discover: false
# issuer: "https://github.com/"
# client_id: "your-client-id" # TO BE FILLED
@@ -272,7 +272,12 @@ OIDC_PROVIDER_CONFIG_SCHEMA = {
"idp_icon": {"type": "string"},
"idp_brand": {
"type": "string",
- # MSC2758-style namespaced identifier
+ "minLength": 1,
+ "maxLength": 255,
+ "pattern": "^[a-z][a-z0-9_.-]*$",
+ },
+ "idp_unstable_brand": {
+ "type": "string",
"minLength": 1,
"maxLength": 255,
"pattern": "^[a-z][a-z0-9_.-]*$",
@@ -466,6 +471,7 @@ def _parse_oidc_config_dict(
idp_name=oidc_config.get("idp_name", "OIDC"),
idp_icon=idp_icon,
idp_brand=oidc_config.get("idp_brand"),
+ unstable_idp_brand=oidc_config.get("unstable_idp_brand"),
discover=oidc_config.get("discover", True),
issuer=oidc_config["issuer"],
client_id=oidc_config["client_id"],
@@ -512,6 +518,9 @@ class OidcProviderConfig:
# Optional brand identifier for this IdP.
idp_brand = attr.ib(type=Optional[str])
+ # Optional brand identifier for the unstable API (see MSC2858).
+ unstable_idp_brand = attr.ib(type=Optional[str])
+
# whether the OIDC discovery mechanism is used to discover endpoints
discover = attr.ib(type=bool)
diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py
index 04972f9cf0..cb67589f7d 100644
--- a/synapse/handlers/cas_handler.py
+++ b/synapse/handlers/cas_handler.py
@@ -83,6 +83,7 @@ class CasHandler:
# the SsoIdentityProvider protocol type.
self.idp_icon = None
self.idp_brand = None
+ self.unstable_idp_brand = None
self._sso_handler = hs.get_sso_handler()
diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py
index f5d1821127..01c91f9d1c 100644
--- a/synapse/handlers/oidc_handler.py
+++ b/synapse/handlers/oidc_handler.py
@@ -330,6 +330,9 @@ class OidcProvider:
# optional brand identifier for this auth provider
self.idp_brand = provider.idp_brand
+ # Optional brand identifier for the unstable API (see MSC2858).
+ self.unstable_idp_brand = provider.unstable_idp_brand
+
self._sso_handler = hs.get_sso_handler()
self._sso_handler.register_identity_provider(self)
diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py
index a9645b77d8..ec2ba11c75 100644
--- a/synapse/handlers/saml_handler.py
+++ b/synapse/handlers/saml_handler.py
@@ -81,6 +81,7 @@ class SamlHandler(BaseHandler):
# the SsoIdentityProvider protocol type.
self.idp_icon = None
self.idp_brand = None
+ self.unstable_idp_brand = None
# a map from saml session id to Saml2SessionData object
self._outstanding_requests_dict = {} # type: Dict[str, Saml2SessionData]
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index 6ef459acff..415b1c2d17 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -98,6 +98,11 @@ class SsoIdentityProvider(Protocol):
"""Optional branding identifier"""
return None
+ @property
+ def unstable_idp_brand(self) -> Optional[str]:
+ """Optional brand identifier for the unstable API (see MSC2858)."""
+ return None
+
@abc.abstractmethod
async def handle_redirect_request(
self,
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 34bc1bd49b..e4c352f572 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -14,10 +14,12 @@
# limitations under the License.
import logging
+import re
from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Optional
from synapse.api.errors import Codes, LoginError, SynapseError
from synapse.api.ratelimiting import Ratelimiter
+from synapse.api.urls import CLIENT_API_PREFIX
from synapse.appservice import ApplicationService
from synapse.handlers.sso import SsoIdentityProvider
from synapse.http import get_request_uri
@@ -94,11 +96,21 @@ class LoginRestServlet(RestServlet):
flows.append({"type": LoginRestServlet.CAS_TYPE})
if self.cas_enabled or self.saml2_enabled or self.oidc_enabled:
- sso_flow = {"type": LoginRestServlet.SSO_TYPE} # type: JsonDict
+ sso_flow = {
+ "type": LoginRestServlet.SSO_TYPE,
+ "identity_providers": [
+ _get_auth_flow_dict_for_idp(
+ idp,
+ )
+ for idp in self._sso_handler.get_identity_providers().values()
+ ],
+ } # type: JsonDict
if self._msc2858_enabled:
+ # backwards-compatibility support for clients which don't
+ # support the stable API yet
sso_flow["org.matrix.msc2858.identity_providers"] = [
- _get_auth_flow_dict_for_idp(idp)
+ _get_auth_flow_dict_for_idp(idp, use_unstable_brands=True)
for idp in self._sso_handler.get_identity_providers().values()
]
@@ -331,22 +343,38 @@ class LoginRestServlet(RestServlet):
return result
-def _get_auth_flow_dict_for_idp(idp: SsoIdentityProvider) -> JsonDict:
+def _get_auth_flow_dict_for_idp(
+ idp: SsoIdentityProvider, use_unstable_brands: bool = False
+) -> JsonDict:
"""Return an entry for the login flow dict
Returns an entry suitable for inclusion in "identity_providers" in the
response to GET /_matrix/client/r0/login
+
+ Args:
+ idp: the identity provider to describe
+ use_unstable_brands: whether we should use brand identifiers suitable
+ for the unstable API
"""
e = {"id": idp.idp_id, "name": idp.idp_name} # type: JsonDict
if idp.idp_icon:
e["icon"] = idp.idp_icon
if idp.idp_brand:
e["brand"] = idp.idp_brand
+ # use the stable brand identifier if the unstable identifier isn't defined.
+ if use_unstable_brands and idp.unstable_idp_brand:
+ e["brand"] = idp.unstable_idp_brand
return e
class SsoRedirectServlet(RestServlet):
- PATTERNS = client_patterns("/login/(cas|sso)/redirect$", v1=True)
+ PATTERNS = list(client_patterns("/login/(cas|sso)/redirect$", v1=True)) + [
+ re.compile(
+ "^"
+ + CLIENT_API_PREFIX
+ + "/r0/login/sso/redirect/(?P<idp_id>[A-Za-z0-9_.~-]+)$"
+ )
+ ]
def __init__(self, hs: "HomeServer"):
# make sure that the relevant handlers are instantiated, so that they
@@ -364,7 +392,8 @@ class SsoRedirectServlet(RestServlet):
def register(self, http_server: HttpServer) -> None:
super().register(http_server)
if self._msc2858_enabled:
- # expose additional endpoint for MSC2858 support
+ # expose additional endpoint for MSC2858 support: backwards-compat support
+ # for clients which don't yet support the stable endpoints.
http_server.register_paths(
"GET",
client_patterns(
|