diff --git a/synapse/config/oidc_config.py b/synapse/config/oidc_config.py
index c705de5694..fddca19223 100644
--- a/synapse/config/oidc_config.py
+++ b/synapse/config/oidc_config.py
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2020 Quentin Gliech
-# Copyright 2020 The Matrix.org Foundation C.I.C.
+# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import string
from typing import Optional, Type
import attr
@@ -38,7 +39,7 @@ class OIDCConfig(Config):
oidc_config = config.get("oidc_config")
if oidc_config and oidc_config.get("enabled", False):
- validate_config(OIDC_PROVIDER_CONFIG_SCHEMA, oidc_config, "oidc_config")
+ validate_config(OIDC_PROVIDER_CONFIG_SCHEMA, oidc_config, ("oidc_config",))
self.oidc_provider = _parse_oidc_config_dict(oidc_config)
if not self.oidc_provider:
@@ -205,6 +206,8 @@ OIDC_PROVIDER_CONFIG_SCHEMA = {
"type": "object",
"required": ["issuer", "client_id", "client_secret"],
"properties": {
+ "idp_id": {"type": "string", "minLength": 1, "maxLength": 128},
+ "idp_name": {"type": "string"},
"discover": {"type": "boolean"},
"issuer": {"type": "string"},
"client_id": {"type": "string"},
@@ -277,7 +280,17 @@ def _parse_oidc_config_dict(oidc_config: JsonDict) -> "OidcProviderConfig":
"methods: %s" % (", ".join(missing_methods),)
)
+ # MSC2858 will appy certain limits in what can be used as an IdP id, so let's
+ # enforce those limits now.
+ idp_id = oidc_config.get("idp_id", "oidc")
+ valid_idp_chars = set(string.ascii_letters + string.digits + "-._~")
+
+ if any(c not in valid_idp_chars for c in idp_id):
+ raise ConfigError('idp_id may only contain A-Z, a-z, 0-9, "-", ".", "_", "~"')
+
return OidcProviderConfig(
+ idp_id=idp_id,
+ idp_name=oidc_config.get("idp_name", "OIDC"),
discover=oidc_config.get("discover", True),
issuer=oidc_config["issuer"],
client_id=oidc_config["client_id"],
@@ -296,8 +309,15 @@ def _parse_oidc_config_dict(oidc_config: JsonDict) -> "OidcProviderConfig":
)
-@attr.s
+@attr.s(slots=True, frozen=True)
class OidcProviderConfig:
+ # a unique identifier for this identity provider. Used in the 'user_external_ids'
+ # table, as well as the query/path parameter used in the login protocol.
+ idp_id = attr.ib(type=str)
+
+ # user-facing name for this identity provider.
+ idp_name = attr.ib(type=str)
+
# whether the OIDC discovery mechanism is used to discover endpoints
discover = attr.ib(type=bool)
diff --git a/synapse/config/sso.py b/synapse/config/sso.py
index 1aeb1c5c92..366f0d4698 100644
--- a/synapse/config/sso.py
+++ b/synapse/config/sso.py
@@ -37,6 +37,7 @@ class SSOConfig(Config):
self.sso_error_template,
sso_account_deactivated_template,
sso_auth_success_template,
+ self.sso_auth_bad_user_template,
) = self.read_templates(
[
"sso_login_idp_picker.html",
@@ -45,6 +46,7 @@ class SSOConfig(Config):
"sso_error.html",
"sso_account_deactivated.html",
"sso_auth_success.html",
+ "sso_auth_bad_user.html",
],
template_dir,
)
@@ -160,6 +162,14 @@ class SSOConfig(Config):
#
# This template has no additional variables.
#
+ # * HTML page shown after a user-interactive authentication session which
+ # does not map correctly onto the expected user: 'sso_auth_bad_user.html'.
+ #
+ # When rendering, this template is given the following variables:
+ # * server_name: the homeserver's name.
+ # * user_id_to_verify: the MXID of the user that we are trying to
+ # validate.
+ #
# * HTML page shown during single sign-on if a deactivated user (according to Synapse's database)
# attempts to login: 'sso_account_deactivated.html'.
#
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 4f881a439a..18cd2b62f0 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -263,10 +263,6 @@ class AuthHandler(BaseHandler):
# authenticating for an operation to occur on their account.
self._sso_auth_confirm_template = hs.config.sso_auth_confirm_template
- # The following template is shown after a successful user interactive
- # authentication session. It tells the user they can close the window.
- self._sso_auth_success_template = hs.config.sso_auth_success_template
-
# The following template is shown during the SSO authentication process if
# the account is deactivated.
self._sso_account_deactivated_template = (
@@ -1394,27 +1390,6 @@ class AuthHandler(BaseHandler):
description=session.description, redirect_url=redirect_url,
)
- async def complete_sso_ui_auth(
- self, registered_user_id: str, session_id: str, request: Request,
- ):
- """Having figured out a mxid for this user, complete the HTTP request
-
- Args:
- registered_user_id: The registered user ID to complete SSO login for.
- session_id: The ID of the user-interactive auth session.
- request: The request to complete.
- """
- # Mark the stage of the authentication as successful.
- # Save the user who authenticated with SSO, this will be used to ensure
- # that the account be modified is also the person who logged in.
- await self.store.mark_ui_auth_stage_complete(
- session_id, LoginType.SSO, registered_user_id
- )
-
- # Render the HTML and return.
- html = self._sso_auth_success_template
- respond_with_html(request, 200, html)
-
async def complete_sso_login(
self,
registered_user_id: str,
diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py
index d6347bb1b8..f63a90ec5c 100644
--- a/synapse/handlers/oidc_handler.py
+++ b/synapse/handlers/oidc_handler.py
@@ -175,7 +175,7 @@ class OidcHandler:
session_data = self._token_generator.verify_oidc_session_token(
session, state
)
- except MacaroonDeserializationException as e:
+ except (MacaroonDeserializationException, ValueError) as e:
logger.exception("Invalid session")
self._sso_handler.render_error(request, "invalid_session", str(e))
return
@@ -253,10 +253,10 @@ class OidcProvider:
self._server_name = hs.config.server_name # type: str
# identifier for the external_ids table
- self.idp_id = "oidc"
+ self.idp_id = provider.idp_id
# user-facing name of this auth provider
- self.idp_name = "OIDC"
+ self.idp_name = provider.idp_name
self._sso_handler = hs.get_sso_handler()
@@ -656,6 +656,7 @@ class OidcProvider:
cookie = self._token_generator.generate_oidc_session_token(
state=state,
session_data=OidcSessionData(
+ idp_id=self.idp_id,
nonce=nonce,
client_redirect_url=client_redirect_url.decode(),
ui_auth_session_id=ui_auth_session_id,
@@ -924,6 +925,7 @@ class OidcSessionTokenGenerator:
macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("type = session")
macaroon.add_first_party_caveat("state = %s" % (state,))
+ macaroon.add_first_party_caveat("idp_id = %s" % (session_data.idp_id,))
macaroon.add_first_party_caveat("nonce = %s" % (session_data.nonce,))
macaroon.add_first_party_caveat(
"client_redirect_url = %s" % (session_data.client_redirect_url,)
@@ -952,6 +954,9 @@ class OidcSessionTokenGenerator:
Returns:
The data extracted from the session cookie
+
+ Raises:
+ ValueError if an expected caveat is missing from the macaroon.
"""
macaroon = pymacaroons.Macaroon.deserialize(session)
@@ -960,6 +965,7 @@ class OidcSessionTokenGenerator:
v.satisfy_exact("type = session")
v.satisfy_exact("state = %s" % (state,))
v.satisfy_general(lambda c: c.startswith("nonce = "))
+ v.satisfy_general(lambda c: c.startswith("idp_id = "))
v.satisfy_general(lambda c: c.startswith("client_redirect_url = "))
# Sometimes there's a UI auth session ID, it seems to be OK to attempt
# to always satisfy this.
@@ -968,9 +974,9 @@ class OidcSessionTokenGenerator:
v.verify(macaroon, self._macaroon_secret_key)
- # Extract the `nonce`, `client_redirect_url`, and maybe the
- # `ui_auth_session_id` from the token.
+ # Extract the session data from the token.
nonce = self._get_value_from_macaroon(macaroon, "nonce")
+ idp_id = self._get_value_from_macaroon(macaroon, "idp_id")
client_redirect_url = self._get_value_from_macaroon(
macaroon, "client_redirect_url"
)
@@ -983,6 +989,7 @@ class OidcSessionTokenGenerator:
return OidcSessionData(
nonce=nonce,
+ idp_id=idp_id,
client_redirect_url=client_redirect_url,
ui_auth_session_id=ui_auth_session_id,
)
@@ -998,7 +1005,7 @@ class OidcSessionTokenGenerator:
The extracted value
Raises:
- Exception: if the caveat was not in the macaroon
+ ValueError: if the caveat was not in the macaroon
"""
prefix = key + " = "
for caveat in macaroon.caveats:
@@ -1019,6 +1026,9 @@ class OidcSessionTokenGenerator:
class OidcSessionData:
"""The attributes which are stored in a OIDC session cookie"""
+ # the Identity Provider being used
+ idp_id = attr.ib(type=str)
+
# The `nonce` parameter passed to the OIDC provider.
nonce = attr.ib(type=str)
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index d096e0b091..dcc85e9871 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -22,7 +22,9 @@ from typing_extensions import NoReturn, Protocol
from twisted.web.http import Request
+from synapse.api.constants import LoginType
from synapse.api.errors import Codes, RedirectException, SynapseError
+from synapse.handlers.ui_auth import UIAuthSessionDataConstants
from synapse.http import get_request_user_agent
from synapse.http.server import respond_with_html
from synapse.http.site import SynapseRequest
@@ -146,8 +148,13 @@ class SsoHandler:
self._store = hs.get_datastore()
self._server_name = hs.hostname
self._registration_handler = hs.get_registration_handler()
- self._error_template = hs.config.sso_error_template
self._auth_handler = hs.get_auth_handler()
+ self._error_template = hs.config.sso_error_template
+ self._bad_user_template = hs.config.sso_auth_bad_user_template
+
+ # The following template is shown after a successful user interactive
+ # authentication session. It tells the user they can close the window.
+ self._sso_auth_success_template = hs.config.sso_auth_success_template
# a lock on the mappings
self._mapping_lock = Linearizer(name="sso_user_mapping", clock=hs.get_clock())
@@ -577,19 +584,45 @@ class SsoHandler:
auth_provider_id, remote_user_id,
)
+ user_id_to_verify = await self._auth_handler.get_session_data(
+ ui_auth_session_id, UIAuthSessionDataConstants.REQUEST_USER_ID
+ ) # type: str
+
if not user_id:
logger.warning(
"Remote user %s/%s has not previously logged in here: UIA will fail",
auth_provider_id,
remote_user_id,
)
- # Let the UIA flow handle this the same as if they presented creds for a
- # different user.
- user_id = ""
+ elif user_id != user_id_to_verify:
+ logger.warning(
+ "Remote user %s/%s mapped onto incorrect user %s: UIA will fail",
+ auth_provider_id,
+ remote_user_id,
+ user_id,
+ )
+ else:
+ # success!
+ # Mark the stage of the authentication as successful.
+ await self._store.mark_ui_auth_stage_complete(
+ ui_auth_session_id, LoginType.SSO, user_id
+ )
+
+ # Render the HTML confirmation page and return.
+ html = self._sso_auth_success_template
+ respond_with_html(request, 200, html)
+ return
+
+ # the user_id didn't match: mark the stage of the authentication as unsuccessful
+ await self._store.mark_ui_auth_stage_complete(
+ ui_auth_session_id, LoginType.SSO, ""
+ )
- await self._auth_handler.complete_sso_ui_auth(
- user_id, ui_auth_session_id, request
+ # render an error page.
+ html = self._bad_user_template.render(
+ server_name=self._server_name, user_id_to_verify=user_id_to_verify,
)
+ respond_with_html(request, 200, html)
async def check_username_availability(
self, localpart: str, session_id: str,
diff --git a/synapse/res/templates/sso_auth_bad_user.html b/synapse/res/templates/sso_auth_bad_user.html
new file mode 100644
index 0000000000..3611191bf9
--- /dev/null
+++ b/synapse/res/templates/sso_auth_bad_user.html
@@ -0,0 +1,18 @@
+<html>
+<head>
+ <title>Authentication Failed</title>
+</head>
+ <body>
+ <div>
+ <p>
+ We were unable to validate your <tt>{{server_name | e}}</tt> account via
+ single-sign-on (SSO), because the SSO Identity Provider returned
+ different details than when you logged in.
+ </p>
+ <p>
+ Try the operation again, and ensure that you use the same details on
+ the Identity Provider as when you log into your account.
+ </p>
+ </div>
+ </body>
+</html>
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 1b6ccd51c8..c128889bf9 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -25,6 +25,7 @@ from twisted.enterprise.adbapi import Connection
from synapse.logging.opentracing import log_kv, set_tag, trace
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import DatabasePool, make_in_list_sql_clause
+from synapse.storage.engines import PostgresEngine
from synapse.storage.types import Cursor
from synapse.types import JsonDict
from synapse.util import json_encoder
@@ -513,21 +514,35 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore):
for user_chunk in batch_iter(user_ids, 100):
clause, params = make_in_list_sql_clause(
- txn.database_engine, "k.user_id", user_chunk
- )
- sql = (
- """
- SELECT k.user_id, k.keytype, k.keydata, k.stream_id
- FROM e2e_cross_signing_keys k
- INNER JOIN (SELECT user_id, keytype, MAX(stream_id) AS stream_id
- FROM e2e_cross_signing_keys
- GROUP BY user_id, keytype) s
- USING (user_id, stream_id, keytype)
- WHERE
- """
- + clause
+ txn.database_engine, "user_id", user_chunk
)
+ # Fetch the latest key for each type per user.
+ if isinstance(self.database_engine, PostgresEngine):
+ # The `DISTINCT ON` clause will pick the *first* row it
+ # encounters, so ordering by stream ID desc will ensure we get
+ # the latest key.
+ sql = """
+ SELECT DISTINCT ON (user_id, keytype) user_id, keytype, keydata, stream_id
+ FROM e2e_cross_signing_keys
+ WHERE %(clause)s
+ ORDER BY user_id, keytype, stream_id DESC
+ """ % {
+ "clause": clause
+ }
+ else:
+ # SQLite has special handling for bare columns when using
+ # MIN/MAX with a `GROUP BY` clause where it picks the value from
+ # a row that matches the MIN/MAX.
+ sql = """
+ SELECT user_id, keytype, keydata, MAX(stream_id)
+ FROM e2e_cross_signing_keys
+ WHERE %(clause)s
+ GROUP BY user_id, keytype
+ """ % {
+ "clause": clause
+ }
+
txn.execute(sql, params)
rows = self.db_pool.cursor_to_dict(txn)
diff --git a/synapse/util/iterutils.py b/synapse/util/iterutils.py
index f7b4857a84..6ef2b008a4 100644
--- a/synapse/util/iterutils.py
+++ b/synapse/util/iterutils.py
@@ -92,7 +92,7 @@ def sorted_topologically(
node = heapq.heappop(zero_degree)
yield node
- for edge in reverse_graph[node]:
+ for edge in reverse_graph.get(node, []):
if edge in degree_map:
degree_map[edge] -= 1
if degree_map[edge] == 0:
|