summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/api/room_versions.py32
-rw-r--r--synapse/app/homeserver.py2
-rw-r--r--synapse/config/sso.py27
-rw-r--r--synapse/events/utils.py16
-rw-r--r--synapse/handlers/cas_handler.py38
-rw-r--r--synapse/handlers/oidc_handler.py18
-rw-r--r--synapse/handlers/profile.py4
-rw-r--r--synapse/handlers/room.py2
-rw-r--r--synapse/handlers/saml_handler.py28
-rw-r--r--synapse/handlers/sso.py100
-rw-r--r--synapse/logging/context.py50
-rw-r--r--synapse/res/templates/sso_login_idp_picker.html28
-rw-r--r--synapse/rest/client/v1/login.py89
-rw-r--r--synapse/rest/client/v2_alpha/auth.py34
-rw-r--r--synapse/rest/synapse/client/pick_idp.py82
-rw-r--r--synapse/static/client/login/style.css5
-rw-r--r--synapse/storage/database.py8
-rw-r--r--synapse/util/metrics.py10
18 files changed, 427 insertions, 146 deletions
diff --git a/synapse/api/room_versions.py b/synapse/api/room_versions.py
index f3ecbf36b6..de2cc15d33 100644
--- a/synapse/api/room_versions.py
+++ b/synapse/api/room_versions.py
@@ -51,11 +51,11 @@ class RoomDisposition:
 class RoomVersion:
     """An object which describes the unique attributes of a room version."""
 
-    identifier = attr.ib()  # str; the identifier for this version
-    disposition = attr.ib()  # str; one of the RoomDispositions
-    event_format = attr.ib()  # int; one of the EventFormatVersions
-    state_res = attr.ib()  # int; one of the StateResolutionVersions
-    enforce_key_validity = attr.ib()  # bool
+    identifier = attr.ib(type=str)  # the identifier for this version
+    disposition = attr.ib(type=str)  # one of the RoomDispositions
+    event_format = attr.ib(type=int)  # one of the EventFormatVersions
+    state_res = attr.ib(type=int)  # one of the StateResolutionVersions
+    enforce_key_validity = attr.ib(type=bool)
 
     # bool: before MSC2261/MSC2432, m.room.aliases had special auth rules and redaction rules
     special_case_aliases_auth = attr.ib(type=bool)
@@ -64,9 +64,11 @@ class RoomVersion:
     # * Floats
     # * NaN, Infinity, -Infinity
     strict_canonicaljson = attr.ib(type=bool)
-    # bool: MSC2209: Check 'notifications' key while verifying
+    # MSC2209: Check 'notifications' key while verifying
     # m.room.power_levels auth rules.
     limit_notifications_power_levels = attr.ib(type=bool)
+    # MSC2174/MSC2176: Apply updated redaction rules algorithm.
+    msc2176_redaction_rules = attr.ib(type=bool)
 
 
 class RoomVersions:
@@ -79,6 +81,7 @@ class RoomVersions:
         special_case_aliases_auth=True,
         strict_canonicaljson=False,
         limit_notifications_power_levels=False,
+        msc2176_redaction_rules=False,
     )
     V2 = RoomVersion(
         "2",
@@ -89,6 +92,7 @@ class RoomVersions:
         special_case_aliases_auth=True,
         strict_canonicaljson=False,
         limit_notifications_power_levels=False,
+        msc2176_redaction_rules=False,
     )
     V3 = RoomVersion(
         "3",
@@ -99,6 +103,7 @@ class RoomVersions:
         special_case_aliases_auth=True,
         strict_canonicaljson=False,
         limit_notifications_power_levels=False,
+        msc2176_redaction_rules=False,
     )
     V4 = RoomVersion(
         "4",
@@ -109,6 +114,7 @@ class RoomVersions:
         special_case_aliases_auth=True,
         strict_canonicaljson=False,
         limit_notifications_power_levels=False,
+        msc2176_redaction_rules=False,
     )
     V5 = RoomVersion(
         "5",
@@ -119,6 +125,7 @@ class RoomVersions:
         special_case_aliases_auth=True,
         strict_canonicaljson=False,
         limit_notifications_power_levels=False,
+        msc2176_redaction_rules=False,
     )
     V6 = RoomVersion(
         "6",
@@ -129,6 +136,18 @@ class RoomVersions:
         special_case_aliases_auth=False,
         strict_canonicaljson=True,
         limit_notifications_power_levels=True,
+        msc2176_redaction_rules=False,
+    )
+    MSC2176 = RoomVersion(
+        "org.matrix.msc2176",
+        RoomDisposition.UNSTABLE,
+        EventFormatVersions.V3,
+        StateResolutionVersions.V2,
+        enforce_key_validity=True,
+        special_case_aliases_auth=False,
+        strict_canonicaljson=True,
+        limit_notifications_power_levels=True,
+        msc2176_redaction_rules=True,
     )
 
 
@@ -141,5 +160,6 @@ KNOWN_ROOM_VERSIONS = {
         RoomVersions.V4,
         RoomVersions.V5,
         RoomVersions.V6,
+        RoomVersions.MSC2176,
     )
 }  # type: Dict[str, RoomVersion]
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index 8d9b53be53..b1d9817a6a 100644
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -63,6 +63,7 @@ from synapse.rest import ClientRestResource
 from synapse.rest.admin import AdminRestResource
 from synapse.rest.health import HealthResource
 from synapse.rest.key.v2 import KeyApiV2Resource
+from synapse.rest.synapse.client.pick_idp import PickIdpResource
 from synapse.rest.synapse.client.pick_username import pick_username_resource
 from synapse.rest.well_known import WellKnownResource
 from synapse.server import HomeServer
@@ -194,6 +195,7 @@ class SynapseHomeServer(HomeServer):
                     "/.well-known/matrix/client": WellKnownResource(self),
                     "/_synapse/admin": AdminRestResource(self),
                     "/_synapse/client/pick_username": pick_username_resource(self),
+                    "/_synapse/client/pick_idp": PickIdpResource(self),
                 }
             )
 
diff --git a/synapse/config/sso.py b/synapse/config/sso.py
index 93bbd40937..1aeb1c5c92 100644
--- a/synapse/config/sso.py
+++ b/synapse/config/sso.py
@@ -31,6 +31,7 @@ class SSOConfig(Config):
 
         # Read templates from disk
         (
+            self.sso_login_idp_picker_template,
             self.sso_redirect_confirm_template,
             self.sso_auth_confirm_template,
             self.sso_error_template,
@@ -38,6 +39,7 @@ class SSOConfig(Config):
             sso_auth_success_template,
         ) = self.read_templates(
             [
+                "sso_login_idp_picker.html",
                 "sso_redirect_confirm.html",
                 "sso_auth_confirm.html",
                 "sso_error.html",
@@ -98,6 +100,31 @@ class SSOConfig(Config):
             #
             # Synapse will look for the following templates in this directory:
             #
+            # * HTML page to prompt the user to choose an Identity Provider during
+            #   login: 'sso_login_idp_picker.html'.
+            #
+            #   This is only used if multiple SSO Identity Providers are configured.
+            #
+            #   When rendering, this template is given the following variables:
+            #     * redirect_url: the URL that the user will be redirected to after
+            #       login. Needs manual escaping (see
+            #       https://jinja.palletsprojects.com/en/2.11.x/templates/#html-escaping).
+            #
+            #     * server_name: the homeserver's name.
+            #
+            #     * providers: a list of available Identity Providers. Each element is
+            #       an object with the following attributes:
+            #         * idp_id: unique identifier for the IdP
+            #         * idp_name: user-facing name for the IdP
+            #
+            #   The rendered HTML page should contain a form which submits its results
+            #   back as a GET request, with the following query parameters:
+            #
+            #     * redirectUrl: the client redirect URI (ie, the `redirect_url` passed
+            #       to the template)
+            #
+            #     * idp: the 'idp_id' of the chosen IDP.
+            #
             # * HTML page for a confirmation step before redirecting back to the client
             #   with the login token: 'sso_redirect_confirm.html'.
             #
diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index 14f7f1156f..9c22e33813 100644
--- a/synapse/events/utils.py
+++ b/synapse/events/utils.py
@@ -79,13 +79,15 @@ def prune_event_dict(room_version: RoomVersion, event_dict: dict) -> dict:
         "state_key",
         "depth",
         "prev_events",
-        "prev_state",
         "auth_events",
         "origin",
         "origin_server_ts",
-        "membership",
     ]
 
+    # Room versions from before MSC2176 had additional allowed keys.
+    if not room_version.msc2176_redaction_rules:
+        allowed_keys.extend(["prev_state", "membership"])
+
     event_type = event_dict["type"]
 
     new_content = {}
@@ -98,6 +100,10 @@ def prune_event_dict(room_version: RoomVersion, event_dict: dict) -> dict:
     if event_type == EventTypes.Member:
         add_fields("membership")
     elif event_type == EventTypes.Create:
+        # MSC2176 rules state that create events cannot be redacted.
+        if room_version.msc2176_redaction_rules:
+            return event_dict
+
         add_fields("creator")
     elif event_type == EventTypes.JoinRules:
         add_fields("join_rule")
@@ -112,10 +118,16 @@ def prune_event_dict(room_version: RoomVersion, event_dict: dict) -> dict:
             "kick",
             "redact",
         )
+
+        if room_version.msc2176_redaction_rules:
+            add_fields("invite")
+
     elif event_type == EventTypes.Aliases and room_version.special_case_aliases_auth:
         add_fields("aliases")
     elif event_type == EventTypes.RoomHistoryVisibility:
         add_fields("history_visibility")
+    elif event_type == EventTypes.Redaction and room_version.msc2176_redaction_rules:
+        add_fields("redacts")
 
     allowed_fields = {k: v for k, v in event_dict.items() if k in allowed_keys}
 
diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py
index fca210a5a6..f3430c6713 100644
--- a/synapse/handlers/cas_handler.py
+++ b/synapse/handlers/cas_handler.py
@@ -75,10 +75,15 @@ class CasHandler:
         self._http_client = hs.get_proxied_http_client()
 
         # identifier for the external_ids table
-        self._auth_provider_id = "cas"
+        self.idp_id = "cas"
+
+        # user-facing name of this auth provider
+        self.idp_name = "CAS"
 
         self._sso_handler = hs.get_sso_handler()
 
+        self._sso_handler.register_identity_provider(self)
+
     def _build_service_param(self, args: Dict[str, str]) -> str:
         """
         Generates a value to use as the "service" parameter when redirecting or
@@ -105,7 +110,7 @@ class CasHandler:
         Args:
             ticket: The CAS ticket from the client.
             service_args: Additional arguments to include in the service URL.
-                Should be the same as those passed to `get_redirect_url`.
+                Should be the same as those passed to `handle_redirect_request`.
 
         Raises:
             CasError: If there's an error parsing the CAS response.
@@ -184,16 +189,31 @@ class CasHandler:
 
         return CasResponse(user, attributes)
 
-    def get_redirect_url(self, service_args: Dict[str, str]) -> str:
-        """
-        Generates a URL for the CAS server where the client should be redirected.
+    async def handle_redirect_request(
+        self,
+        request: SynapseRequest,
+        client_redirect_url: Optional[bytes],
+        ui_auth_session_id: Optional[str] = None,
+    ) -> str:
+        """Generates a URL for the CAS server where the client should be redirected.
 
         Args:
-            service_args: Additional arguments to include in the final redirect URL.
+            request: the incoming HTTP request
+            client_redirect_url: the URL that we should redirect the
+                client to after login (or None for UI Auth).
+            ui_auth_session_id: The session ID of the ongoing UI Auth (or
+                None if this is a login).
 
         Returns:
-            The URL to redirect the client to.
+            URL to redirect to
         """
+
+        if ui_auth_session_id:
+            service_args = {"session": ui_auth_session_id}
+        else:
+            assert client_redirect_url
+            service_args = {"redirectUrl": client_redirect_url.decode("utf8")}
+
         args = urllib.parse.urlencode(
             {"service": self._build_service_param(service_args)}
         )
@@ -275,7 +295,7 @@ class CasHandler:
         # first check if we're doing a UIA
         if session:
             return await self._sso_handler.complete_sso_ui_auth_request(
-                self._auth_provider_id, cas_response.username, session, request,
+                self.idp_id, cas_response.username, session, request,
             )
 
         # otherwise, we're handling a login request.
@@ -375,7 +395,7 @@ class CasHandler:
             return None
 
         await self._sso_handler.complete_sso_login_request(
-            self._auth_provider_id,
+            self.idp_id,
             cas_response.username,
             request,
             client_redirect_url,
diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py
index 709f8dfc13..6835c6c462 100644
--- a/synapse/handlers/oidc_handler.py
+++ b/synapse/handlers/oidc_handler.py
@@ -119,10 +119,15 @@ class OidcHandler(BaseHandler):
         self._macaroon_secret_key = hs.config.macaroon_secret_key
 
         # identifier for the external_ids table
-        self._auth_provider_id = "oidc"
+        self.idp_id = "oidc"
+
+        # user-facing name of this auth provider
+        self.idp_name = "OIDC"
 
         self._sso_handler = hs.get_sso_handler()
 
+        self._sso_handler.register_identity_provider(self)
+
     def _validate_metadata(self):
         """Verifies the provider metadata.
 
@@ -475,7 +480,7 @@ class OidcHandler(BaseHandler):
     async def handle_redirect_request(
         self,
         request: SynapseRequest,
-        client_redirect_url: bytes,
+        client_redirect_url: Optional[bytes],
         ui_auth_session_id: Optional[str] = None,
     ) -> str:
         """Handle an incoming request to /login/sso/redirect
@@ -499,7 +504,7 @@ class OidcHandler(BaseHandler):
             request: the incoming request from the browser.
                 We'll respond to it with a redirect and a cookie.
             client_redirect_url: the URL that we should redirect the client to
-                when everything is done
+                when everything is done (or None for UI Auth)
             ui_auth_session_id: The session ID of the ongoing UI Auth (or
                 None if this is a login).
 
@@ -511,6 +516,9 @@ class OidcHandler(BaseHandler):
         state = generate_token()
         nonce = generate_token()
 
+        if not client_redirect_url:
+            client_redirect_url = b""
+
         cookie = self._generate_oidc_session_token(
             state=state,
             nonce=nonce,
@@ -682,7 +690,7 @@ class OidcHandler(BaseHandler):
                 return
 
             return await self._sso_handler.complete_sso_ui_auth_request(
-                self._auth_provider_id, remote_user_id, ui_auth_session_id, request
+                self.idp_id, remote_user_id, ui_auth_session_id, request
             )
 
         # otherwise, it's a login
@@ -923,7 +931,7 @@ class OidcHandler(BaseHandler):
             extra_attributes = await get_extra_attributes(userinfo, token)
 
         await self._sso_handler.complete_sso_login_request(
-            self._auth_provider_id,
+            self.idp_id,
             remote_user_id,
             request,
             client_redirect_url,
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index dee0ef45e7..36f9ee4b71 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -156,7 +156,7 @@ class ProfileHandler(BaseHandler):
             except HttpResponseException as e:
                 raise e.to_synapse_error()
 
-            return result["displayname"]
+            return result.get("displayname")
 
     async def set_displayname(
         self,
@@ -246,7 +246,7 @@ class ProfileHandler(BaseHandler):
             except HttpResponseException as e:
                 raise e.to_synapse_error()
 
-            return result["avatar_url"]
+            return result.get("avatar_url")
 
     async def set_avatar_url(
         self,
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 1f809fa161..3bece6d668 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -365,7 +365,7 @@ class RoomCreationHandler(BaseHandler):
         creation_content = {
             "room_version": new_room_version.identifier,
             "predecessor": {"room_id": old_room_id, "event_id": tombstone_event_id},
-        }
+        }  # type: JsonDict
 
         # Check if old room was non-federatable
 
diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py
index 5fa7ab3f8b..a8376543c9 100644
--- a/synapse/handlers/saml_handler.py
+++ b/synapse/handlers/saml_handler.py
@@ -73,27 +73,41 @@ class SamlHandler(BaseHandler):
         )
 
         # identifier for the external_ids table
-        self._auth_provider_id = "saml"
+        self.idp_id = "saml"
+
+        # user-facing name of this auth provider
+        self.idp_name = "SAML"
 
         # a map from saml session id to Saml2SessionData object
         self._outstanding_requests_dict = {}  # type: Dict[str, Saml2SessionData]
 
         self._sso_handler = hs.get_sso_handler()
+        self._sso_handler.register_identity_provider(self)
 
-    def handle_redirect_request(
-        self, client_redirect_url: bytes, ui_auth_session_id: Optional[str] = None
-    ) -> bytes:
+    async def handle_redirect_request(
+        self,
+        request: SynapseRequest,
+        client_redirect_url: Optional[bytes],
+        ui_auth_session_id: Optional[str] = None,
+    ) -> str:
         """Handle an incoming request to /login/sso/redirect
 
         Args:
+            request: the incoming HTTP request
             client_redirect_url: the URL that we should redirect the
-                client to when everything is done
+                client to after login (or None for UI Auth).
             ui_auth_session_id: The session ID of the ongoing UI Auth (or
                 None if this is a login).
 
         Returns:
             URL to redirect to
         """
+        if not client_redirect_url:
+            # Some SAML identity providers (e.g. Google) require a
+            # RelayState parameter on requests, so pass in a dummy redirect URL
+            # (which will never get used).
+            client_redirect_url = b"unused"
+
         reqid, info = self._saml_client.prepare_for_authenticate(
             entityid=self._saml_idp_entityid, relay_state=client_redirect_url
         )
@@ -210,7 +224,7 @@ class SamlHandler(BaseHandler):
                 return
 
             return await self._sso_handler.complete_sso_ui_auth_request(
-                self._auth_provider_id,
+                self.idp_id,
                 remote_user_id,
                 current_session.ui_auth_session_id,
                 request,
@@ -306,7 +320,7 @@ class SamlHandler(BaseHandler):
             return None
 
         await self._sso_handler.complete_sso_login_request(
-            self._auth_provider_id,
+            self.idp_id,
             remote_user_id,
             request,
             client_redirect_url,
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index 33cd6bc178..2da1ea2223 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -12,15 +12,17 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import abc
 import logging
-from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Optional
+from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Mapping, Optional
+from urllib.parse import urlencode
 
 import attr
-from typing_extensions import NoReturn
+from typing_extensions import NoReturn, Protocol
 
 from twisted.web.http import Request
 
-from synapse.api.errors import RedirectException, SynapseError
+from synapse.api.errors import Codes, RedirectException, SynapseError
 from synapse.http.server import respond_with_html
 from synapse.http.site import SynapseRequest
 from synapse.types import JsonDict, UserID, contains_invalid_mxid_characters
@@ -40,6 +42,58 @@ class MappingException(Exception):
     """
 
 
+class SsoIdentityProvider(Protocol):
+    """Abstract base class to be implemented by SSO Identity Providers
+
+    An Identity Provider, or IdP, is an external HTTP service which authenticates a user
+    to say whether they should be allowed to log in, or perform a given action.
+
+    Synapse supports various implementations of IdPs, including OpenID Connect, SAML,
+    and CAS.
+
+    The main entry point is `handle_redirect_request`, which should return a URI to
+    redirect the user's browser to the IdP's authentication page.
+
+    Each IdP should be registered with the SsoHandler via
+    `hs.get_sso_handler().register_identity_provider()`, so that requests to
+    `/_matrix/client/r0/login/sso/redirect` can be correctly dispatched.
+    """
+
+    @property
+    @abc.abstractmethod
+    def idp_id(self) -> str:
+        """A unique identifier for this SSO provider
+
+        Eg, "saml", "cas", "github"
+        """
+
+    @property
+    @abc.abstractmethod
+    def idp_name(self) -> str:
+        """User-facing name for this provider"""
+
+    @abc.abstractmethod
+    async def handle_redirect_request(
+        self,
+        request: SynapseRequest,
+        client_redirect_url: Optional[bytes],
+        ui_auth_session_id: Optional[str] = None,
+    ) -> str:
+        """Handle an incoming request to /login/sso/redirect
+
+        Args:
+            request: the incoming HTTP request
+            client_redirect_url: the URL that we should redirect the
+                client to after login (or None for UI Auth).
+            ui_auth_session_id: The session ID of the ongoing UI Auth (or
+                None if this is a login).
+
+        Returns:
+            URL to redirect to
+        """
+        raise NotImplementedError()
+
+
 @attr.s
 class UserAttributes:
     # the localpart of the mxid that the mapper has assigned to the user.
@@ -100,6 +154,18 @@ class SsoHandler:
         # a map from session id to session data
         self._username_mapping_sessions = {}  # type: Dict[str, UsernameMappingSession]
 
+        # map from idp_id to SsoIdentityProvider
+        self._identity_providers = {}  # type: Dict[str, SsoIdentityProvider]
+
+    def register_identity_provider(self, p: SsoIdentityProvider):
+        p_id = p.idp_id
+        assert p_id not in self._identity_providers
+        self._identity_providers[p_id] = p
+
+    def get_identity_providers(self) -> Mapping[str, SsoIdentityProvider]:
+        """Get the configured identity providers"""
+        return self._identity_providers
+
     def render_error(
         self,
         request: Request,
@@ -124,6 +190,34 @@ class SsoHandler:
         )
         respond_with_html(request, code, html)
 
+    async def handle_redirect_request(
+        self, request: SynapseRequest, client_redirect_url: bytes,
+    ) -> str:
+        """Handle a request to /login/sso/redirect
+
+        Args:
+            request: incoming HTTP request
+            client_redirect_url: the URL that we should redirect the
+                client to after login.
+
+        Returns:
+             the URI to redirect to
+        """
+        if not self._identity_providers:
+            raise SynapseError(
+                400, "Homeserver not configured for SSO.", errcode=Codes.UNRECOGNIZED
+            )
+
+        # if we only have one auth provider, redirect to it directly
+        if len(self._identity_providers) == 1:
+            ap = next(iter(self._identity_providers.values()))
+            return await ap.handle_redirect_request(request, client_redirect_url)
+
+        # otherwise, redirect to the IDP picker
+        return "/_synapse/client/pick_idp?" + urlencode(
+            (("redirectUrl", client_redirect_url),)
+        )
+
     async def get_sso_user_by_remote_user_id(
         self, auth_provider_id: str, remote_user_id: str
     ) -> Optional[str]:
diff --git a/synapse/logging/context.py b/synapse/logging/context.py
index a507a83e93..c2db8b45f3 100644
--- a/synapse/logging/context.py
+++ b/synapse/logging/context.py
@@ -252,7 +252,12 @@ class LoggingContext:
         "scope",
     ]
 
-    def __init__(self, name=None, parent_context=None, request=None) -> None:
+    def __init__(
+        self,
+        name: Optional[str] = None,
+        parent_context: "Optional[LoggingContext]" = None,
+        request: Optional[str] = None,
+    ) -> None:
         self.previous_context = current_context()
         self.name = name
 
@@ -536,20 +541,20 @@ class LoggingContextFilter(logging.Filter):
     def __init__(self, request: str = ""):
         self._default_request = request
 
-    def filter(self, record) -> Literal[True]:
+    def filter(self, record: logging.LogRecord) -> Literal[True]:
         """Add each fields from the logging contexts to the record.
         Returns:
             True to include the record in the log output.
         """
         context = current_context()
-        record.request = self._default_request
+        record.request = self._default_request  # type: ignore
 
         # context should never be None, but if it somehow ends up being, then
         # we end up in a death spiral of infinite loops, so let's check, for
         # robustness' sake.
         if context is not None:
             # Logging is interested in the request.
-            record.request = context.request
+            record.request = context.request  # type: ignore
 
         return True
 
@@ -616,9 +621,7 @@ def set_current_context(context: LoggingContextOrSentinel) -> LoggingContextOrSe
     return current
 
 
-def nested_logging_context(
-    suffix: str, parent_context: Optional[LoggingContext] = None
-) -> LoggingContext:
+def nested_logging_context(suffix: str) -> LoggingContext:
     """Creates a new logging context as a child of another.
 
     The nested logging context will have a 'request' made up of the parent context's
@@ -632,20 +635,23 @@ def nested_logging_context(
             # ... do stuff
 
     Args:
-        suffix (str): suffix to add to the parent context's 'request'.
-        parent_context (LoggingContext|None): parent context. Will use the current context
-            if None.
+        suffix: suffix to add to the parent context's 'request'.
 
     Returns:
         LoggingContext: new logging context.
     """
-    if parent_context is not None:
-        context = parent_context  # type: LoggingContextOrSentinel
+    curr_context = current_context()
+    if not curr_context:
+        logger.warning(
+            "Starting nested logging context from sentinel context: metrics will be lost"
+        )
+        parent_context = None
+        prefix = ""
     else:
-        context = current_context()
-    return LoggingContext(
-        parent_context=context, request=str(context.request) + "-" + suffix
-    )
+        assert isinstance(curr_context, LoggingContext)
+        parent_context = curr_context
+        prefix = str(parent_context.request)
+    return LoggingContext(parent_context=parent_context, request=prefix + "-" + suffix)
 
 
 def preserve_fn(f):
@@ -822,10 +828,18 @@ def defer_to_threadpool(reactor, threadpool, f, *args, **kwargs):
         Deferred: A Deferred which fires a callback with the result of `f`, or an
             errback if `f` throws an exception.
     """
-    logcontext = current_context()
+    curr_context = current_context()
+    if not curr_context:
+        logger.warning(
+            "Calling defer_to_threadpool from sentinel context: metrics will be lost"
+        )
+        parent_context = None
+    else:
+        assert isinstance(curr_context, LoggingContext)
+        parent_context = curr_context
 
     def g():
-        with LoggingContext(parent_context=logcontext):
+        with LoggingContext(parent_context=parent_context):
             return f(*args, **kwargs)
 
     return make_deferred_yieldable(threads.deferToThreadPool(reactor, threadpool, g))
diff --git a/synapse/res/templates/sso_login_idp_picker.html b/synapse/res/templates/sso_login_idp_picker.html
new file mode 100644
index 0000000000..f53c9cd679
--- /dev/null
+++ b/synapse/res/templates/sso_login_idp_picker.html
@@ -0,0 +1,28 @@
+<!DOCTYPE html>
+<html lang="en">
+    <head>
+        <meta charset="UTF-8">
+        <link rel="stylesheet" href="/_matrix/static/client/login/style.css">
+        <title>{{server_name | e}} Login</title>
+    </head>
+    <body>
+        <div id="container">
+            <h1 id="title">{{server_name | e}} Login</h1>
+            <div class="login_flow">
+                <p>Choose one of the following identity providers:</p>
+            <form>
+                <input type="hidden" name="redirectUrl" value="{{redirect_url | e}}">
+                <ul class="radiobuttons">
+{% for p in providers %}
+                    <li>
+                        <input type="radio" name="idp" id="prov{{loop.index}}" value="{{p.idp_id}}">
+                        <label for="prov{{loop.index}}">{{p.idp_name | e}}</label>
+                    </li>
+{% endfor %}
+                </ul>
+                <input type="submit" class="button button--full-width" id="button-submit" value="Submit">
+            </form>
+            </div>
+        </div>
+    </body>
+</html>
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 5f4c6703db..ebc346105b 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -311,48 +311,31 @@ class LoginRestServlet(RestServlet):
         return result
 
 
-class BaseSSORedirectServlet(RestServlet):
-    """Common base class for /login/sso/redirect impls"""
-
+class SsoRedirectServlet(RestServlet):
     PATTERNS = client_patterns("/login/(cas|sso)/redirect", v1=True)
 
+    def __init__(self, hs: "HomeServer"):
+        # make sure that the relevant handlers are instantiated, so that they
+        # register themselves with the main SSOHandler.
+        if hs.config.cas_enabled:
+            hs.get_cas_handler()
+        elif hs.config.saml2_enabled:
+            hs.get_saml_handler()
+        elif hs.config.oidc_enabled:
+            hs.get_oidc_handler()
+        self._sso_handler = hs.get_sso_handler()
+
     async def on_GET(self, request: SynapseRequest):
-        args = request.args
-        if b"redirectUrl" not in args:
-            return 400, "Redirect URL not specified for SSO auth"
-        client_redirect_url = args[b"redirectUrl"][0]
-        sso_url = await self.get_sso_url(request, client_redirect_url)
+        client_redirect_url = parse_string(
+            request, "redirectUrl", required=True, encoding=None
+        )
+        sso_url = await self._sso_handler.handle_redirect_request(
+            request, client_redirect_url
+        )
+        logger.info("Redirecting to %s", sso_url)
         request.redirect(sso_url)
         finish_request(request)
 
-    async def get_sso_url(
-        self, request: SynapseRequest, client_redirect_url: bytes
-    ) -> bytes:
-        """Get the URL to redirect to, to perform SSO auth
-
-        Args:
-            request: The client request to redirect.
-            client_redirect_url: the URL that we should redirect the
-                client to when everything is done
-
-        Returns:
-            URL to redirect to
-        """
-        # to be implemented by subclasses
-        raise NotImplementedError()
-
-
-class CasRedirectServlet(BaseSSORedirectServlet):
-    def __init__(self, hs):
-        self._cas_handler = hs.get_cas_handler()
-
-    async def get_sso_url(
-        self, request: SynapseRequest, client_redirect_url: bytes
-    ) -> bytes:
-        return self._cas_handler.get_redirect_url(
-            {"redirectUrl": client_redirect_url}
-        ).encode("ascii")
-
 
 class CasTicketServlet(RestServlet):
     PATTERNS = client_patterns("/login/cas/ticket", v1=True)
@@ -379,40 +362,8 @@ class CasTicketServlet(RestServlet):
         )
 
 
-class SAMLRedirectServlet(BaseSSORedirectServlet):
-    PATTERNS = client_patterns("/login/sso/redirect", v1=True)
-
-    def __init__(self, hs):
-        self._saml_handler = hs.get_saml_handler()
-
-    async def get_sso_url(
-        self, request: SynapseRequest, client_redirect_url: bytes
-    ) -> bytes:
-        return self._saml_handler.handle_redirect_request(client_redirect_url)
-
-
-class OIDCRedirectServlet(BaseSSORedirectServlet):
-    """Implementation for /login/sso/redirect for the OIDC login flow."""
-
-    PATTERNS = client_patterns("/login/sso/redirect", v1=True)
-
-    def __init__(self, hs):
-        self._oidc_handler = hs.get_oidc_handler()
-
-    async def get_sso_url(
-        self, request: SynapseRequest, client_redirect_url: bytes
-    ) -> bytes:
-        return await self._oidc_handler.handle_redirect_request(
-            request, client_redirect_url
-        )
-
-
 def register_servlets(hs, http_server):
     LoginRestServlet(hs).register(http_server)
+    SsoRedirectServlet(hs).register(http_server)
     if hs.config.cas_enabled:
-        CasRedirectServlet(hs).register(http_server)
         CasTicketServlet(hs).register(http_server)
-    elif hs.config.saml2_enabled:
-        SAMLRedirectServlet(hs).register(http_server)
-    elif hs.config.oidc_enabled:
-        OIDCRedirectServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v2_alpha/auth.py b/synapse/rest/client/v2_alpha/auth.py
index fab077747f..9b9514632f 100644
--- a/synapse/rest/client/v2_alpha/auth.py
+++ b/synapse/rest/client/v2_alpha/auth.py
@@ -14,15 +14,20 @@
 # limitations under the License.
 
 import logging
+from typing import TYPE_CHECKING
 
 from synapse.api.constants import LoginType
 from synapse.api.errors import SynapseError
 from synapse.api.urls import CLIENT_API_PREFIX
+from synapse.handlers.sso import SsoIdentityProvider
 from synapse.http.server import respond_with_html
 from synapse.http.servlet import RestServlet, parse_string
 
 from ._base import client_patterns
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
@@ -35,7 +40,7 @@ class AuthRestServlet(RestServlet):
 
     PATTERNS = client_patterns(r"/auth/(?P<stagetype>[\w\.]*)/fallback/web")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.hs = hs
         self.auth = hs.get_auth()
@@ -85,31 +90,20 @@ class AuthRestServlet(RestServlet):
         elif stagetype == LoginType.SSO:
             # Display a confirmation page which prompts the user to
             # re-authenticate with their SSO provider.
-            if self._cas_enabled:
-                # Generate a request to CAS that redirects back to an endpoint
-                # to verify the successful authentication.
-                sso_redirect_url = self._cas_handler.get_redirect_url(
-                    {"session": session},
-                )
 
+            if self._cas_enabled:
+                sso_auth_provider = self._cas_handler  # type: SsoIdentityProvider
             elif self._saml_enabled:
-                # Some SAML identity providers (e.g. Google) require a
-                # RelayState parameter on requests. It is not necessary here, so
-                # pass in a dummy redirect URL (which will never get used).
-                client_redirect_url = b"unused"
-                sso_redirect_url = self._saml_handler.handle_redirect_request(
-                    client_redirect_url, session
-                )
-
+                sso_auth_provider = self._saml_handler
             elif self._oidc_enabled:
-                client_redirect_url = b""
-                sso_redirect_url = await self._oidc_handler.handle_redirect_request(
-                    request, client_redirect_url, session
-                )
-
+                sso_auth_provider = self._oidc_handler
             else:
                 raise SynapseError(400, "Homeserver not configured for SSO.")
 
+            sso_redirect_url = await sso_auth_provider.handle_redirect_request(
+                request, None, session
+            )
+
             html = await self.auth_handler.start_sso_ui_auth(sso_redirect_url, session)
 
         else:
diff --git a/synapse/rest/synapse/client/pick_idp.py b/synapse/rest/synapse/client/pick_idp.py
new file mode 100644
index 0000000000..e5b720bbca
--- /dev/null
+++ b/synapse/rest/synapse/client/pick_idp.py
@@ -0,0 +1,82 @@
+# -*- coding: utf-8 -*-
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+from typing import TYPE_CHECKING
+
+from synapse.http.server import (
+    DirectServeHtmlResource,
+    finish_request,
+    respond_with_html,
+)
+from synapse.http.servlet import parse_string
+from synapse.http.site import SynapseRequest
+
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
+logger = logging.getLogger(__name__)
+
+
+class PickIdpResource(DirectServeHtmlResource):
+    """IdP picker resource.
+
+    This resource gets mounted under /_synapse/client/pick_idp. It serves an HTML page
+    which prompts the user to choose an Identity Provider from the list.
+    """
+
+    def __init__(self, hs: "HomeServer"):
+        super().__init__()
+        self._sso_handler = hs.get_sso_handler()
+        self._sso_login_idp_picker_template = (
+            hs.config.sso.sso_login_idp_picker_template
+        )
+        self._server_name = hs.hostname
+
+    async def _async_render_GET(self, request: SynapseRequest) -> None:
+        client_redirect_url = parse_string(request, "redirectUrl", required=True)
+        idp = parse_string(request, "idp", required=False)
+
+        # if we need to pick an IdP, do so
+        if not idp:
+            return await self._serve_id_picker(request, client_redirect_url)
+
+        # otherwise, redirect to the IdP's redirect URI
+        providers = self._sso_handler.get_identity_providers()
+        auth_provider = providers.get(idp)
+        if not auth_provider:
+            logger.info("Unknown idp %r", idp)
+            self._sso_handler.render_error(
+                request, "unknown_idp", "Unknown identity provider ID"
+            )
+            return
+
+        sso_url = await auth_provider.handle_redirect_request(
+            request, client_redirect_url.encode("utf8")
+        )
+        logger.info("Redirecting to %s", sso_url)
+        request.redirect(sso_url)
+        finish_request(request)
+
+    async def _serve_id_picker(
+        self, request: SynapseRequest, client_redirect_url: str
+    ) -> None:
+        # otherwise, serve up the IdP picker
+        providers = self._sso_handler.get_identity_providers()
+        html = self._sso_login_idp_picker_template.render(
+            redirect_url=client_redirect_url,
+            server_name=self._server_name,
+            providers=providers.values(),
+        )
+        respond_with_html(request, 200, html)
diff --git a/synapse/static/client/login/style.css b/synapse/static/client/login/style.css
index 83e4f6abc8..dd76714a92 100644
--- a/synapse/static/client/login/style.css
+++ b/synapse/static/client/login/style.css
@@ -31,6 +31,11 @@ form {
     margin: 10px 0 0 0;
 }
 
+ul.radiobuttons {
+    text-align: left;
+    list-style: none;
+}
+
 /*
  * Add some padding to the viewport.
  */
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index d1b5760c2c..b70ca3087b 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -42,7 +42,6 @@ from synapse.api.errors import StoreError
 from synapse.config.database import DatabaseConnectionConfig
 from synapse.logging.context import (
     LoggingContext,
-    LoggingContextOrSentinel,
     current_context,
     make_deferred_yieldable,
 )
@@ -671,12 +670,15 @@ class DatabasePool:
         Returns:
             The result of func
         """
-        parent_context = current_context()  # type: Optional[LoggingContextOrSentinel]
-        if not parent_context:
+        curr_context = current_context()
+        if not curr_context:
             logger.warning(
                 "Starting db connection from sentinel context: metrics will be lost"
             )
             parent_context = None
+        else:
+            assert isinstance(curr_context, LoggingContext)
+            parent_context = curr_context
 
         start_time = monotonic_time()
 
diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py
index ffdea0de8d..24123d5cc4 100644
--- a/synapse/util/metrics.py
+++ b/synapse/util/metrics.py
@@ -108,7 +108,15 @@ class Measure:
     def __init__(self, clock, name):
         self.clock = clock
         self.name = name
-        parent_context = current_context()
+        curr_context = current_context()
+        if not curr_context:
+            logger.warning(
+                "Starting metrics collection from sentinel context: metrics will be lost"
+            )
+            parent_context = None
+        else:
+            assert isinstance(curr_context, LoggingContext)
+            parent_context = curr_context
         self._logging_context = LoggingContext(
             "Measure[%s]" % (self.name,), parent_context
         )