diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 5163afd86c..c7dc07008a 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2014 - 2016 OpenMarket Ltd
# Copyright 2017 Vector Creations Ltd
+# Copyright 2019 - 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -25,6 +26,7 @@ from typing import (
Dict,
Iterable,
List,
+ Mapping,
Optional,
Tuple,
Union,
@@ -181,17 +183,12 @@ class AuthHandler(BaseHandler):
# better way to break the loop
account_handler = ModuleApi(hs, self)
- self.password_providers = []
- for module, config in hs.config.password_providers:
- try:
- self.password_providers.append(
- module(config=config, account_handler=account_handler)
- )
- except Exception as e:
- logger.error("Error while initializing %r: %s", module, e)
- raise
+ self.password_providers = [
+ PasswordProvider.load(module, config, account_handler)
+ for module, config in hs.config.password_providers
+ ]
- logger.info("Extra password_providers: %r", self.password_providers)
+ logger.info("Extra password_providers: %s", self.password_providers)
self.hs = hs # FIXME better possibility to access registrationHandler later?
self.macaroon_gen = hs.get_macaroon_generator()
@@ -205,15 +202,23 @@ class AuthHandler(BaseHandler):
# type in the list. (NB that the spec doesn't require us to do so and
# clients which favour types that they don't understand over those that
# they do are technically broken)
+
+ # start out by assuming PASSWORD is enabled; we will remove it later if not.
login_types = []
- if self._password_enabled:
+ if hs.config.password_localdb_enabled:
login_types.append(LoginType.PASSWORD)
+
for provider in self.password_providers:
if hasattr(provider, "get_supported_login_types"):
for t in provider.get_supported_login_types().keys():
if t not in login_types:
login_types.append(t)
+
+ if not self._password_enabled:
+ login_types.remove(LoginType.PASSWORD)
+
self._supported_login_types = login_types
+
# Login types and UI Auth types have a heavy overlap, but are not
# necessarily identical. Login types have SSO (and other login types)
# added in the rest layer, see synapse.rest.client.v1.login.LoginRestServerlet.on_GET.
@@ -230,6 +235,13 @@ class AuthHandler(BaseHandler):
burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
)
+ # Ratelimitier for failed /login attempts
+ self._failed_login_attempts_ratelimiter = Ratelimiter(
+ clock=hs.get_clock(),
+ rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
+ burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
+ )
+
self._clock = self.hs.get_clock()
# Expire old UI auth sessions after a period of time.
@@ -642,14 +654,8 @@ class AuthHandler(BaseHandler):
res = await checker.check_auth(authdict, clientip=clientip)
return res
- # build a v1-login-style dict out of the authdict and fall back to the
- # v1 code
- user_id = authdict.get("user")
-
- if user_id is None:
- raise SynapseError(400, "", Codes.MISSING_PARAM)
-
- (canonical_id, callback) = await self.validate_login(user_id, authdict)
+ # fall back to the v1 login flow
+ canonical_id, _ = await self.validate_login(authdict)
return canonical_id
def _get_params_recaptcha(self) -> dict:
@@ -824,17 +830,17 @@ class AuthHandler(BaseHandler):
return self._supported_login_types
async def validate_login(
- self, username: str, login_submission: Dict[str, Any]
+ self, login_submission: Dict[str, Any], ratelimit: bool = False,
) -> Tuple[str, Optional[Callable[[Dict[str, str]], None]]]:
"""Authenticates the user for the /login API
- Also used by the user-interactive auth flow to validate
- m.login.password auth types.
+ Also used by the user-interactive auth flow to validate auth types which don't
+ have an explicit UIA handler, including m.password.auth.
Args:
- username: username supplied by the user
login_submission: the whole of the login submission
(including 'type' and other relevant fields)
+ ratelimit: whether to apply the failed_login_attempt ratelimiter
Returns:
A tuple of the canonical user id, and optional callback
to be called once the access token and device id are issued
@@ -843,38 +849,160 @@ class AuthHandler(BaseHandler):
SynapseError if there was a problem with the request
LoginError if there was an authentication problem.
"""
-
- if username.startswith("@"):
- qualified_user_id = username
- else:
- qualified_user_id = UserID(username, self.hs.hostname).to_string()
-
login_type = login_submission.get("type")
- known_login_type = False
+ if not isinstance(login_type, str):
+ raise SynapseError(400, "Bad parameter: type", Codes.INVALID_PARAM)
+
+ # ideally, we wouldn't be checking the identifier unless we know we have a login
+ # method which uses it (https://github.com/matrix-org/synapse/issues/8836)
+ #
+ # But the auth providers' check_auth interface requires a username, so in
+ # practice we can only support login methods which we can map to a username
+ # anyway.
# special case to check for "password" for the check_password interface
# for the auth providers
password = login_submission.get("password")
-
if login_type == LoginType.PASSWORD:
if not self._password_enabled:
raise SynapseError(400, "Password login has been disabled.")
- if not password:
- raise SynapseError(400, "Missing parameter: password")
+ if not isinstance(password, str):
+ raise SynapseError(400, "Bad parameter: password", Codes.INVALID_PARAM)
- for provider in self.password_providers:
- if hasattr(provider, "check_password") and login_type == LoginType.PASSWORD:
- known_login_type = True
- is_valid = await provider.check_password(qualified_user_id, password)
- if is_valid:
- return qualified_user_id, None
+ # map old-school login fields into new-school "identifier" fields.
+ identifier_dict = convert_client_dict_legacy_fields_to_identifier(
+ login_submission
+ )
- if not hasattr(provider, "get_supported_login_types") or not hasattr(
- provider, "check_auth"
- ):
- # this password provider doesn't understand custom login types
- continue
+ # convert phone type identifiers to generic threepids
+ if identifier_dict["type"] == "m.id.phone":
+ identifier_dict = login_id_phone_to_thirdparty(identifier_dict)
+
+ # convert threepid identifiers to user IDs
+ if identifier_dict["type"] == "m.id.thirdparty":
+ address = identifier_dict.get("address")
+ medium = identifier_dict.get("medium")
+
+ if medium is None or address is None:
+ raise SynapseError(400, "Invalid thirdparty identifier")
+
+ # For emails, canonicalise the address.
+ # We store all email addresses canonicalised in the DB.
+ # (See add_threepid in synapse/handlers/auth.py)
+ if medium == "email":
+ try:
+ address = canonicalise_email(address)
+ except ValueError as e:
+ raise SynapseError(400, str(e))
+
+ # We also apply account rate limiting using the 3PID as a key, as
+ # otherwise using 3PID bypasses the ratelimiting based on user ID.
+ if ratelimit:
+ self._failed_login_attempts_ratelimiter.ratelimit(
+ (medium, address), update=False
+ )
+ # Check for login providers that support 3pid login types
+ if login_type == LoginType.PASSWORD:
+ # we've already checked that there is a (valid) password field
+ assert isinstance(password, str)
+ (
+ canonical_user_id,
+ callback_3pid,
+ ) = await self.check_password_provider_3pid(medium, address, password)
+ if canonical_user_id:
+ # Authentication through password provider and 3pid succeeded
+ return canonical_user_id, callback_3pid
+
+ # No password providers were able to handle this 3pid
+ # Check local store
+ user_id = await self.hs.get_datastore().get_user_id_by_threepid(
+ medium, address
+ )
+ if not user_id:
+ logger.warning(
+ "unknown 3pid identifier medium %s, address %r", medium, address
+ )
+ # We mark that we've failed to log in here, as
+ # `check_password_provider_3pid` might have returned `None` due
+ # to an incorrect password, rather than the account not
+ # existing.
+ #
+ # If it returned None but the 3PID was bound then we won't hit
+ # this code path, which is fine as then the per-user ratelimit
+ # will kick in below.
+ if ratelimit:
+ self._failed_login_attempts_ratelimiter.can_do_action(
+ (medium, address)
+ )
+ raise LoginError(403, "", errcode=Codes.FORBIDDEN)
+
+ identifier_dict = {"type": "m.id.user", "user": user_id}
+
+ # by this point, the identifier should be an m.id.user: if it's anything
+ # else, we haven't understood it.
+ if identifier_dict["type"] != "m.id.user":
+ raise SynapseError(400, "Unknown login identifier type")
+
+ username = identifier_dict.get("user")
+ if not username:
+ raise SynapseError(400, "User identifier is missing 'user' key")
+
+ if username.startswith("@"):
+ qualified_user_id = username
+ else:
+ qualified_user_id = UserID(username, self.hs.hostname).to_string()
+
+ # Check if we've hit the failed ratelimit (but don't update it)
+ if ratelimit:
+ self._failed_login_attempts_ratelimiter.ratelimit(
+ qualified_user_id.lower(), update=False
+ )
+
+ try:
+ return await self._validate_userid_login(username, login_submission)
+ except LoginError:
+ # The user has failed to log in, so we need to update the rate
+ # limiter. Using `can_do_action` avoids us raising a ratelimit
+ # exception and masking the LoginError. The actual ratelimiting
+ # should have happened above.
+ if ratelimit:
+ self._failed_login_attempts_ratelimiter.can_do_action(
+ qualified_user_id.lower()
+ )
+ raise
+
+ async def _validate_userid_login(
+ self, username: str, login_submission: Dict[str, Any],
+ ) -> Tuple[str, Optional[Callable[[Dict[str, str]], None]]]:
+ """Helper for validate_login
+
+ Handles login, once we've mapped 3pids onto userids
+
+ Args:
+ username: the username, from the identifier dict
+ login_submission: the whole of the login submission
+ (including 'type' and other relevant fields)
+ Returns:
+ A tuple of the canonical user id, and optional callback
+ to be called once the access token and device id are issued
+ Raises:
+ StoreError if there was a problem accessing the database
+ SynapseError if there was a problem with the request
+ LoginError if there was an authentication problem.
+ """
+ if username.startswith("@"):
+ qualified_user_id = username
+ else:
+ qualified_user_id = UserID(username, self.hs.hostname).to_string()
+
+ login_type = login_submission.get("type")
+ # we already checked that we have a valid login type
+ assert isinstance(login_type, str)
+
+ known_login_type = False
+
+ for provider in self.password_providers:
supported_login_types = provider.get_supported_login_types()
if login_type not in supported_login_types:
# this password provider doesn't understand this login type
@@ -899,15 +1027,17 @@ class AuthHandler(BaseHandler):
result = await provider.check_auth(username, login_type, login_dict)
if result:
- if isinstance(result, str):
- result = (result, None)
return result
if login_type == LoginType.PASSWORD and self.hs.config.password_localdb_enabled:
known_login_type = True
+ # we've already checked that there is a (valid) password field
+ password = login_submission["password"]
+ assert isinstance(password, str)
+
canonical_user_id = await self._check_local_password(
- qualified_user_id, password # type: ignore
+ qualified_user_id, password
)
if canonical_user_id:
@@ -938,19 +1068,9 @@ class AuthHandler(BaseHandler):
unsuccessful, `user_id` and `callback` are both `None`.
"""
for provider in self.password_providers:
- if hasattr(provider, "check_3pid_auth"):
- # This function is able to return a deferred that either
- # resolves None, meaning authentication failure, or upon
- # success, to a str (which is the user_id) or a tuple of
- # (user_id, callback_func), where callback_func should be run
- # after we've finished everything else
- result = await provider.check_3pid_auth(medium, address, password)
- if result:
- # Check if the return value is a str or a tuple
- if isinstance(result, str):
- # If it's a str, set callback function to None
- result = (result, None)
- return result
+ result = await provider.check_3pid_auth(medium, address, password)
+ if result:
+ return result
return None, None
@@ -1008,16 +1128,11 @@ class AuthHandler(BaseHandler):
# see if any of our auth providers want to know about this
for provider in self.password_providers:
- if hasattr(provider, "on_logged_out"):
- # This might return an awaitable, if it does block the log out
- # until it completes.
- result = provider.on_logged_out(
- user_id=user_info.user_id,
- device_id=user_info.device_id,
- access_token=access_token,
- )
- if inspect.isawaitable(result):
- await result
+ await provider.on_logged_out(
+ user_id=user_info.user_id,
+ device_id=user_info.device_id,
+ access_token=access_token,
+ )
# delete pushers associated with this access token
if user_info.token_id is not None:
@@ -1046,11 +1161,10 @@ class AuthHandler(BaseHandler):
# see if any of our auth providers want to know about this
for provider in self.password_providers:
- if hasattr(provider, "on_logged_out"):
- for token, token_id, device_id in tokens_and_devices:
- await provider.on_logged_out(
- user_id=user_id, device_id=device_id, access_token=token
- )
+ for token, token_id, device_id in tokens_and_devices:
+ await provider.on_logged_out(
+ user_id=user_id, device_id=device_id, access_token=token
+ )
# delete pushers associated with the access tokens
await self.hs.get_pusherpool().remove_pushers_by_access_token(
@@ -1374,3 +1488,127 @@ class MacaroonGenerator:
macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
return macaroon
+
+
+class PasswordProvider:
+ """Wrapper for a password auth provider module
+
+ This class abstracts out all of the backwards-compatibility hacks for
+ password providers, to provide a consistent interface.
+ """
+
+ @classmethod
+ def load(cls, module, config, module_api: ModuleApi) -> "PasswordProvider":
+ try:
+ pp = module(config=config, account_handler=module_api)
+ except Exception as e:
+ logger.error("Error while initializing %r: %s", module, e)
+ raise
+ return cls(pp, module_api)
+
+ def __init__(self, pp, module_api: ModuleApi):
+ self._pp = pp
+ self._module_api = module_api
+
+ self._supported_login_types = {}
+
+ # grandfather in check_password support
+ if hasattr(self._pp, "check_password"):
+ self._supported_login_types[LoginType.PASSWORD] = ("password",)
+
+ g = getattr(self._pp, "get_supported_login_types", None)
+ if g:
+ self._supported_login_types.update(g())
+
+ def __str__(self):
+ return str(self._pp)
+
+ def get_supported_login_types(self) -> Mapping[str, Iterable[str]]:
+ """Get the login types supported by this password provider
+
+ Returns a map from a login type identifier (such as m.login.password) to an
+ iterable giving the fields which must be provided by the user in the submission
+ to the /login API.
+
+ This wrapper adds m.login.password to the list if the underlying password
+ provider supports the check_password() api.
+ """
+ return self._supported_login_types
+
+ async def check_auth(
+ self, username: str, login_type: str, login_dict: JsonDict
+ ) -> Optional[Tuple[str, Optional[Callable]]]:
+ """Check if the user has presented valid login credentials
+
+ This wrapper also calls check_password() if the underlying password provider
+ supports the check_password() api and the login type is m.login.password.
+
+ Args:
+ username: user id presented by the client. Either an MXID or an unqualified
+ username.
+
+ login_type: the login type being attempted - one of the types returned by
+ get_supported_login_types()
+
+ login_dict: the dictionary of login secrets passed by the client.
+
+ Returns: (user_id, callback) where `user_id` is the fully-qualified mxid of the
+ user, and `callback` is an optional callback which will be called with the
+ result from the /login call (including access_token, device_id, etc.)
+ """
+ # first grandfather in a call to check_password
+ if login_type == LoginType.PASSWORD:
+ g = getattr(self._pp, "check_password", None)
+ if g:
+ qualified_user_id = self._module_api.get_qualified_user_id(username)
+ is_valid = await self._pp.check_password(
+ qualified_user_id, login_dict["password"]
+ )
+ if is_valid:
+ return qualified_user_id, None
+
+ g = getattr(self._pp, "check_auth", None)
+ if not g:
+ return None
+ result = await g(username, login_type, login_dict)
+
+ # Check if the return value is a str or a tuple
+ if isinstance(result, str):
+ # If it's a str, set callback function to None
+ return result, None
+
+ return result
+
+ async def check_3pid_auth(
+ self, medium: str, address: str, password: str
+ ) -> Optional[Tuple[str, Optional[Callable]]]:
+ g = getattr(self._pp, "check_3pid_auth", None)
+ if not g:
+ return None
+
+ # This function is able to return a deferred that either
+ # resolves None, meaning authentication failure, or upon
+ # success, to a str (which is the user_id) or a tuple of
+ # (user_id, callback_func), where callback_func should be run
+ # after we've finished everything else
+ result = await g(medium, address, password)
+
+ # Check if the return value is a str or a tuple
+ if isinstance(result, str):
+ # If it's a str, set callback function to None
+ return result, None
+
+ return result
+
+ async def on_logged_out(
+ self, user_id: str, device_id: Optional[str], access_token: str
+ ) -> None:
+ g = getattr(self._pp, "on_logged_out", None)
+ if not g:
+ return
+
+ # This might return an awaitable, if it does block the log out
+ # until it completes.
+ result = g(user_id=user_id, device_id=device_id, access_token=access_token,)
+ if inspect.isawaitable(result):
+ await result
diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index bc3e9607ca..9b3c6b4551 100644
--- a/synapse/handlers/identity.py
+++ b/synapse/handlers/identity.py
@@ -354,7 +354,8 @@ class IdentityHandler(BaseHandler):
raise SynapseError(500, "An error was encountered when sending the email")
token_expires = (
- self.hs.clock.time_msec() + self.hs.config.email_validation_token_lifetime
+ self.hs.get_clock().time_msec()
+ + self.hs.config.email_validation_token_lifetime
)
await self.store.start_or_continue_validation_session(
diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py
index 78c4e94a9d..55c4377890 100644
--- a/synapse/handlers/oidc_handler.py
+++ b/synapse/handlers/oidc_handler.py
@@ -39,7 +39,7 @@ from synapse.handlers._base import BaseHandler
from synapse.handlers.sso import MappingException, UserAttributes
from synapse.http.site import SynapseRequest
from synapse.logging.context import make_deferred_yieldable
-from synapse.types import JsonDict, map_username_to_mxid_localpart
+from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart
from synapse.util import json_decoder
if TYPE_CHECKING:
@@ -898,13 +898,39 @@ class OidcHandler(BaseHandler):
return UserAttributes(**attributes)
+ async def grandfather_existing_users() -> Optional[str]:
+ if self._allow_existing_users:
+ # If allowing existing users we want to generate a single localpart
+ # and attempt to match it.
+ attributes = await oidc_response_to_user_attributes(failures=0)
+
+ user_id = UserID(attributes.localpart, self.server_name).to_string()
+ users = await self.store.get_users_by_id_case_insensitive(user_id)
+ if users:
+ # If an existing matrix ID is returned, then use it.
+ if len(users) == 1:
+ previously_registered_user_id = next(iter(users))
+ elif user_id in users:
+ previously_registered_user_id = user_id
+ else:
+ # Do not attempt to continue generating Matrix IDs.
+ raise MappingException(
+ "Attempted to login as '{}' but it matches more than one user inexactly: {}".format(
+ user_id, users
+ )
+ )
+
+ return previously_registered_user_id
+
+ return None
+
return await self._sso_handler.get_mxid_from_sso(
self._auth_provider_id,
remote_user_id,
user_agent,
ip_address,
oidc_response_to_user_attributes,
- self._allow_existing_users,
+ grandfather_existing_users,
)
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index 426b58da9e..5372753707 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -299,17 +299,22 @@ class PaginationHandler:
"""
return self._purges_by_id.get(purge_id)
- async def purge_room(self, room_id: str) -> None:
- """Purge the given room from the database"""
+ async def purge_room(self, room_id: str, force: bool = False) -> None:
+ """Purge the given room from the database.
+
+ Args:
+ room_id: room to be purged
+ force: set true to skip checking for joined users.
+ """
with await self.pagination_lock.write(room_id):
# check we know about the room
await self.store.get_room_version_id(room_id)
# first check that we have no users in this room
- joined = await self.store.is_host_joined(room_id, self._server_name)
-
- if joined:
- raise SynapseError(400, "Users are still joined to this room")
+ if not force:
+ joined = await self.store.is_host_joined(room_id, self._server_name)
+ if joined:
+ raise SynapseError(400, "Users are still joined to this room")
await self.storage.purge_events.purge_room(room_id)
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 4e693a419e..4d8ffe8821 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -366,7 +366,15 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
# later on.
content = dict(content)
- if not self.allow_per_room_profiles or requester.shadow_banned:
+ # allow the server notices mxid to set room-level profile
+ is_requester_server_notices_user = (
+ self._server_notices_mxid is not None
+ and requester.user.to_string() == self._server_notices_mxid
+ )
+
+ if (
+ not self.allow_per_room_profiles and not is_requester_server_notices_user
+ ) or requester.shadow_banned:
# Strip profile data, knowing that new profile data will be added to the
# event's content in event_creation_handler.create_event() using the target's
# global profile.
diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py
index 34db10ffe4..76d4169fe2 100644
--- a/synapse/handlers/saml_handler.py
+++ b/synapse/handlers/saml_handler.py
@@ -265,10 +265,10 @@ class SamlHandler(BaseHandler):
return UserAttributes(
localpart=result.get("mxid_localpart"),
display_name=result.get("displayname"),
- emails=result.get("emails"),
+ emails=result.get("emails", []),
)
- with (await self._mapping_lock.queue(self._auth_provider_id)):
+ async def grandfather_existing_users() -> Optional[str]:
# backwards-compatibility hack: see if there is an existing user with a
# suitable mapping from the uid
if (
@@ -290,17 +290,18 @@ class SamlHandler(BaseHandler):
if users:
registered_user_id = list(users.keys())[0]
logger.info("Grandfathering mapping to %s", registered_user_id)
- await self.store.record_user_external_id(
- self._auth_provider_id, remote_user_id, registered_user_id
- )
return registered_user_id
+ return None
+
+ with (await self._mapping_lock.queue(self._auth_provider_id)):
return await self._sso_handler.get_mxid_from_sso(
self._auth_provider_id,
remote_user_id,
user_agent,
ip_address,
saml_response_to_remapped_user_attributes,
+ grandfather_existing_users,
)
def expire_sessions(self):
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index d963082210..f42b90e1bc 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -116,7 +116,7 @@ class SsoHandler(BaseHandler):
user_agent: str,
ip_address: str,
sso_to_matrix_id_mapper: Callable[[int], Awaitable[UserAttributes]],
- allow_existing_users: bool = False,
+ grandfather_existing_users: Optional[Callable[[], Awaitable[Optional[str]]]],
) -> str:
"""
Given an SSO ID, retrieve the user ID for it and possibly register the user.
@@ -125,6 +125,10 @@ class SsoHandler(BaseHandler):
if it has that matrix ID is returned regardless of the current mapping
logic.
+ If a callable is provided for grandfathering users, it is called and can
+ potentially return a matrix ID to use. If it does, the SSO ID is linked to
+ this matrix ID for subsequent calls.
+
The mapping function is called (potentially multiple times) to generate
a localpart for the user.
@@ -132,17 +136,6 @@ class SsoHandler(BaseHandler):
given user-agent and IP address and the SSO ID is linked to this matrix
ID for subsequent calls.
- If allow_existing_users is true the mapping function is only called once
- and results in:
-
- 1. The use of a previously registered matrix ID. In this case, the
- SSO ID is linked to the matrix ID. (Note it is possible that
- other SSO IDs are linked to the same matrix ID.)
- 2. An unused localpart, in which case the user is registered (as
- discussed above).
- 3. An error if the generated localpart matches multiple pre-existing
- matrix IDs. Generally this should not happen.
-
Args:
auth_provider_id: A unique identifier for this SSO provider, e.g.
"oidc" or "saml".
@@ -152,8 +145,9 @@ class SsoHandler(BaseHandler):
sso_to_matrix_id_mapper: A callable to generate the user attributes.
The only parameter is an integer which represents the amount of
times the returned mxid localpart mapping has failed.
- allow_existing_users: True if the localpart returned from the
- mapping provider can be linked to an existing matrix ID.
+ grandfather_existing_users: A callable which can return an previously
+ existing matrix ID. The SSO ID is then linked to the returned
+ matrix ID.
Returns:
The user ID associated with the SSO response.
@@ -171,6 +165,16 @@ class SsoHandler(BaseHandler):
if previously_registered_user_id:
return previously_registered_user_id
+ # Check for grandfathering of users.
+ if grandfather_existing_users:
+ previously_registered_user_id = await grandfather_existing_users()
+ if previously_registered_user_id:
+ # Future logins should also match this user ID.
+ await self.store.record_user_external_id(
+ auth_provider_id, remote_user_id, previously_registered_user_id
+ )
+ return previously_registered_user_id
+
# Otherwise, generate a new user.
for i in range(self._MAP_USERNAME_RETRIES):
try:
@@ -194,33 +198,7 @@ class SsoHandler(BaseHandler):
# Check if this mxid already exists
user_id = UserID(attributes.localpart, self.server_name).to_string()
- users = await self.store.get_users_by_id_case_insensitive(user_id)
- # Note, if allow_existing_users is true then the loop is guaranteed
- # to end on the first iteration: either by matching an existing user,
- # raising an error, or registering a new user. See the docstring for
- # more in-depth an explanation.
- if users and allow_existing_users:
- # If an existing matrix ID is returned, then use it.
- if len(users) == 1:
- previously_registered_user_id = next(iter(users))
- elif user_id in users:
- previously_registered_user_id = user_id
- else:
- # Do not attempt to continue generating Matrix IDs.
- raise MappingException(
- "Attempted to login as '{}' but it matches more than one user inexactly: {}".format(
- user_id, users
- )
- )
-
- # Future logins should also match this user ID.
- await self.store.record_user_external_id(
- auth_provider_id, remote_user_id, previously_registered_user_id
- )
-
- return previously_registered_user_id
-
- elif not users:
+ if not await self.store.get_users_by_id_case_insensitive(user_id):
# This mxid is free
break
else:
|