summary refs log tree commit diff
path: root/synapse/rest/client
diff options
context:
space:
mode:
authorRichard van der Hoff <1389908+richvdh@users.noreply.github.com>2021-03-16 11:21:26 +0000
committerGitHub <noreply@github.com>2021-03-16 11:21:26 +0000
commitdd69110d9588b5fc8cca10cba9509d80f88b84f4 (patch)
treed425c7b5b9b781ca57bfaf4c987e3e01e061897e /synapse/rest/client
parentClean up config settings for stats (#9604) (diff)
downloadsynapse-dd69110d9588b5fc8cca10cba9509d80f88b84f4.tar.xz
Add support for stable MSC2858 API (#9617)
The stable format uses different brand identifiers, so we need to support two
identifiers for each IdP.
Diffstat (limited to 'synapse/rest/client')
-rw-r--r--synapse/rest/client/v1/login.py39
1 files changed, 34 insertions, 5 deletions
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 34bc1bd49b..e4c352f572 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -14,10 +14,12 @@
 # limitations under the License.
 
 import logging
+import re
 from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Optional
 
 from synapse.api.errors import Codes, LoginError, SynapseError
 from synapse.api.ratelimiting import Ratelimiter
+from synapse.api.urls import CLIENT_API_PREFIX
 from synapse.appservice import ApplicationService
 from synapse.handlers.sso import SsoIdentityProvider
 from synapse.http import get_request_uri
@@ -94,11 +96,21 @@ class LoginRestServlet(RestServlet):
             flows.append({"type": LoginRestServlet.CAS_TYPE})
 
         if self.cas_enabled or self.saml2_enabled or self.oidc_enabled:
-            sso_flow = {"type": LoginRestServlet.SSO_TYPE}  # type: JsonDict
+            sso_flow = {
+                "type": LoginRestServlet.SSO_TYPE,
+                "identity_providers": [
+                    _get_auth_flow_dict_for_idp(
+                        idp,
+                    )
+                    for idp in self._sso_handler.get_identity_providers().values()
+                ],
+            }  # type: JsonDict
 
             if self._msc2858_enabled:
+                # backwards-compatibility support for clients which don't
+                # support the stable API yet
                 sso_flow["org.matrix.msc2858.identity_providers"] = [
-                    _get_auth_flow_dict_for_idp(idp)
+                    _get_auth_flow_dict_for_idp(idp, use_unstable_brands=True)
                     for idp in self._sso_handler.get_identity_providers().values()
                 ]
 
@@ -331,22 +343,38 @@ class LoginRestServlet(RestServlet):
         return result
 
 
-def _get_auth_flow_dict_for_idp(idp: SsoIdentityProvider) -> JsonDict:
+def _get_auth_flow_dict_for_idp(
+    idp: SsoIdentityProvider, use_unstable_brands: bool = False
+) -> JsonDict:
     """Return an entry for the login flow dict
 
     Returns an entry suitable for inclusion in "identity_providers" in the
     response to GET /_matrix/client/r0/login
+
+    Args:
+        idp: the identity provider to describe
+        use_unstable_brands: whether we should use brand identifiers suitable
+           for the unstable API
     """
     e = {"id": idp.idp_id, "name": idp.idp_name}  # type: JsonDict
     if idp.idp_icon:
         e["icon"] = idp.idp_icon
     if idp.idp_brand:
         e["brand"] = idp.idp_brand
+    # use the stable brand identifier if the unstable identifier isn't defined.
+    if use_unstable_brands and idp.unstable_idp_brand:
+        e["brand"] = idp.unstable_idp_brand
     return e
 
 
 class SsoRedirectServlet(RestServlet):
-    PATTERNS = client_patterns("/login/(cas|sso)/redirect$", v1=True)
+    PATTERNS = list(client_patterns("/login/(cas|sso)/redirect$", v1=True)) + [
+        re.compile(
+            "^"
+            + CLIENT_API_PREFIX
+            + "/r0/login/sso/redirect/(?P<idp_id>[A-Za-z0-9_.~-]+)$"
+        )
+    ]
 
     def __init__(self, hs: "HomeServer"):
         # make sure that the relevant handlers are instantiated, so that they
@@ -364,7 +392,8 @@ class SsoRedirectServlet(RestServlet):
     def register(self, http_server: HttpServer) -> None:
         super().register(http_server)
         if self._msc2858_enabled:
-            # expose additional endpoint for MSC2858 support
+            # expose additional endpoint for MSC2858 support: backwards-compat support
+            # for clients which don't yet support the stable endpoints.
             http_server.register_paths(
                 "GET",
                 client_patterns(