diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi
index 29aa064e57..3ccea4b02d 100644
--- a/synapse/config/_base.pyi
+++ b/synapse/config/_base.pyi
@@ -9,6 +9,7 @@ from synapse.config import (
consent_config,
database,
emailconfig,
+ experimental,
groups,
jwt_config,
key,
@@ -48,6 +49,7 @@ def path_exists(file_path: str): ...
class RootConfig:
server: server.ServerConfig
+ experimental: experimental.ExperimentalConfig
tls: tls.TlsConfig
database: database.DatabaseConfig
logging: logger.LoggingConfig
diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py
new file mode 100644
index 0000000000..b1c1c51e4d
--- /dev/null
+++ b/synapse/config/experimental.py
@@ -0,0 +1,29 @@
+# -*- coding: utf-8 -*-
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.config._base import Config
+from synapse.types import JsonDict
+
+
+class ExperimentalConfig(Config):
+ """Config section for enabling experimental features"""
+
+ section = "experimental"
+
+ def read_config(self, config: JsonDict, **kwargs):
+ experimental = config.get("experimental_features") or {}
+
+ # MSC2858 (multiple SSO identity providers)
+ self.msc2858_enabled = experimental.get("msc2858_enabled", False) # type: bool
diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py
index 4bd2b3587b..64a2429f77 100644
--- a/synapse/config/homeserver.py
+++ b/synapse/config/homeserver.py
@@ -24,6 +24,7 @@ from .cas import CasConfig
from .consent_config import ConsentConfig
from .database import DatabaseConfig
from .emailconfig import EmailConfig
+from .experimental import ExperimentalConfig
from .federation import FederationConfig
from .groups import GroupsConfig
from .jwt_config import JWTConfig
@@ -57,6 +58,7 @@ class HomeServerConfig(RootConfig):
config_classes = [
ServerConfig,
+ ExperimentalConfig,
TlsConfig,
FederationConfig,
CacheConfig,
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index d493327a10..afc1341d09 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -23,7 +23,7 @@ from typing_extensions import NoReturn, Protocol
from twisted.web.http import Request
from synapse.api.constants import LoginType
-from synapse.api.errors import Codes, RedirectException, SynapseError
+from synapse.api.errors import Codes, NotFoundError, RedirectException, SynapseError
from synapse.handlers.ui_auth import UIAuthSessionDataConstants
from synapse.http import get_request_user_agent
from synapse.http.server import respond_with_html
@@ -235,7 +235,10 @@ class SsoHandler:
respond_with_html(request, code, html)
async def handle_redirect_request(
- self, request: SynapseRequest, client_redirect_url: bytes,
+ self,
+ request: SynapseRequest,
+ client_redirect_url: bytes,
+ idp_id: Optional[str],
) -> str:
"""Handle a request to /login/sso/redirect
@@ -243,6 +246,7 @@ class SsoHandler:
request: incoming HTTP request
client_redirect_url: the URL that we should redirect the
client to after login.
+ idp_id: optional identity provider chosen by the client
Returns:
the URI to redirect to
@@ -252,10 +256,19 @@ class SsoHandler:
400, "Homeserver not configured for SSO.", errcode=Codes.UNRECOGNIZED
)
+ # if the client chose an IdP, use that
+ idp = None # type: Optional[SsoIdentityProvider]
+ if idp_id:
+ idp = self._identity_providers.get(idp_id)
+ if not idp:
+ raise NotFoundError("Unknown identity provider")
+
# if we only have one auth provider, redirect to it directly
- if len(self._identity_providers) == 1:
- ap = next(iter(self._identity_providers.values()))
- return await ap.handle_redirect_request(request, client_redirect_url)
+ elif len(self._identity_providers) == 1:
+ idp = next(iter(self._identity_providers.values()))
+
+ if idp:
+ return await idp.handle_redirect_request(request, client_redirect_url)
# otherwise, redirect to the IDP picker
return "/_synapse/client/pick_idp?" + urlencode(
diff --git a/synapse/http/server.py b/synapse/http/server.py
index e464bfe6c7..d69d579b3a 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -22,10 +22,22 @@ import types
import urllib
from http import HTTPStatus
from io import BytesIO
-from typing import Any, Callable, Dict, Iterator, List, Tuple, Union
+from typing import (
+ Any,
+ Awaitable,
+ Callable,
+ Dict,
+ Iterable,
+ Iterator,
+ List,
+ Pattern,
+ Tuple,
+ Union,
+)
import jinja2
from canonicaljson import iterencode_canonical_json
+from typing_extensions import Protocol
from zope.interface import implementer
from twisted.internet import defer, interfaces
@@ -168,11 +180,25 @@ def wrap_async_request_handler(h):
return preserve_fn(wrapped_async_request_handler)
-class HttpServer:
+# Type of a callback method for processing requests
+# it is actually called with a SynapseRequest and a kwargs dict for the params,
+# but I can't figure out how to represent that.
+ServletCallback = Callable[
+ ..., Union[None, Awaitable[None], Tuple[int, Any], Awaitable[Tuple[int, Any]]]
+]
+
+
+class HttpServer(Protocol):
""" Interface for registering callbacks on a HTTP server
"""
- def register_paths(self, method, path_patterns, callback):
+ def register_paths(
+ self,
+ method: str,
+ path_patterns: Iterable[Pattern],
+ callback: ServletCallback,
+ servlet_classname: str,
+ ) -> None:
""" Register a callback that gets fired if we receive a http request
with the given method for a path that matches the given regex.
@@ -180,12 +206,14 @@ class HttpServer:
an unpacked tuple.
Args:
- method (str): The method to listen to.
- path_patterns (list<SRE_Pattern>): The regex used to match requests.
- callback (function): The function to fire if we receive a matched
+ method: The HTTP method to listen to.
+ path_patterns: The regex used to match requests.
+ callback: The function to fire if we receive a matched
request. The first argument will be the request object and
subsequent arguments will be any matched groups from the regex.
- This should return a tuple of (code, response).
+ This should return either tuple of (code, response), or None.
+ servlet_classname (str): The name of the handler to be used in prometheus
+ and opentracing logs.
"""
pass
@@ -354,7 +382,7 @@ class JsonResource(DirectServeJsonResource):
def _get_handler_for_request(
self, request: SynapseRequest
- ) -> Tuple[Callable, str, Dict[str, str]]:
+ ) -> Tuple[ServletCallback, str, Dict[str, str]]:
"""Finds a callback method to handle the given request.
Returns:
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index be938df962..0a561eea60 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -19,7 +19,8 @@ from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Optional
from synapse.api.errors import Codes, LoginError, SynapseError
from synapse.api.ratelimiting import Ratelimiter
from synapse.appservice import ApplicationService
-from synapse.http.server import finish_request
+from synapse.handlers.sso import SsoIdentityProvider
+from synapse.http.server import HttpServer, finish_request
from synapse.http.servlet import (
RestServlet,
parse_json_object_from_request,
@@ -60,11 +61,14 @@ class LoginRestServlet(RestServlet):
self.saml2_enabled = hs.config.saml2_enabled
self.cas_enabled = hs.config.cas_enabled
self.oidc_enabled = hs.config.oidc_enabled
+ self._msc2858_enabled = hs.config.experimental.msc2858_enabled
self.auth = hs.get_auth()
self.auth_handler = self.hs.get_auth_handler()
self.registration_handler = hs.get_registration_handler()
+ self._sso_handler = hs.get_sso_handler()
+
self._well_known_builder = WellKnownBuilder(hs)
self._address_ratelimiter = Ratelimiter(
clock=hs.get_clock(),
@@ -89,8 +93,17 @@ class LoginRestServlet(RestServlet):
flows.append({"type": LoginRestServlet.CAS_TYPE})
if self.cas_enabled or self.saml2_enabled or self.oidc_enabled:
- flows.append({"type": LoginRestServlet.SSO_TYPE})
- # While its valid for us to advertise this login type generally,
+ sso_flow = {"type": LoginRestServlet.SSO_TYPE} # type: JsonDict
+
+ if self._msc2858_enabled:
+ sso_flow["org.matrix.msc2858.identity_providers"] = [
+ _get_auth_flow_dict_for_idp(idp)
+ for idp in self._sso_handler.get_identity_providers().values()
+ ]
+
+ flows.append(sso_flow)
+
+ # While it's valid for us to advertise this login type generally,
# synapse currently only gives out these tokens as part of the
# SSO login flow.
# Generally we don't want to advertise login flows that clients
@@ -311,8 +324,20 @@ class LoginRestServlet(RestServlet):
return result
+def _get_auth_flow_dict_for_idp(idp: SsoIdentityProvider) -> JsonDict:
+ """Return an entry for the login flow dict
+
+ Returns an entry suitable for inclusion in "identity_providers" in the
+ response to GET /_matrix/client/r0/login
+ """
+ e = {"id": idp.idp_id, "name": idp.idp_name} # type: JsonDict
+ if idp.idp_icon:
+ e["icon"] = idp.idp_icon
+ return e
+
+
class SsoRedirectServlet(RestServlet):
- PATTERNS = client_patterns("/login/(cas|sso)/redirect", v1=True)
+ PATTERNS = client_patterns("/login/(cas|sso)/redirect$", v1=True)
def __init__(self, hs: "HomeServer"):
# make sure that the relevant handlers are instantiated, so that they
@@ -324,13 +349,31 @@ class SsoRedirectServlet(RestServlet):
if hs.config.oidc_enabled:
hs.get_oidc_handler()
self._sso_handler = hs.get_sso_handler()
+ self._msc2858_enabled = hs.config.experimental.msc2858_enabled
+
+ def register(self, http_server: HttpServer) -> None:
+ super().register(http_server)
+ if self._msc2858_enabled:
+ # expose additional endpoint for MSC2858 support
+ http_server.register_paths(
+ "GET",
+ client_patterns(
+ "/org.matrix.msc2858/login/sso/redirect/(?P<idp_id>[A-Za-z0-9_.~-]+)$",
+ releases=(),
+ unstable=True,
+ ),
+ self.on_GET,
+ self.__class__.__name__,
+ )
- async def on_GET(self, request: SynapseRequest):
+ async def on_GET(
+ self, request: SynapseRequest, idp_id: Optional[str] = None
+ ) -> None:
client_redirect_url = parse_string(
request, "redirectUrl", required=True, encoding=None
)
sso_url = await self._sso_handler.handle_redirect_request(
- request, client_redirect_url
+ request, client_redirect_url, idp_id,
)
logger.info("Redirecting to %s", sso_url)
request.redirect(sso_url)
|