diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
index bb81c0e81d..d29b066a56 100644
--- a/synapse/handlers/_base.py
+++ b/synapse/handlers/_base.py
@@ -32,6 +32,10 @@ logger = logging.getLogger(__name__)
class BaseHandler:
"""
Common base class for the event handlers.
+
+ Deprecated: new code should not use this. Instead, Handler classes should define the
+ fields they actually need. The utility methods should either be factored out to
+ standalone helper functions, or to different Handler classes.
"""
def __init__(self, hs: "HomeServer"):
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index c7dc07008a..21e568f226 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -14,7 +14,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import inspect
import logging
import time
import unicodedata
@@ -22,6 +21,7 @@ import urllib.parse
from typing import (
TYPE_CHECKING,
Any,
+ Awaitable,
Callable,
Dict,
Iterable,
@@ -36,6 +36,8 @@ import attr
import bcrypt
import pymacaroons
+from twisted.web.http import Request
+
from synapse.api.constants import LoginType
from synapse.api.errors import (
AuthError,
@@ -56,6 +58,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.module_api import ModuleApi
from synapse.types import JsonDict, Requester, UserID
from synapse.util import stringutils as stringutils
+from synapse.util.async_helpers import maybe_awaitable
from synapse.util.msisdn import phone_number_to_msisdn
from synapse.util.threepids import canonicalise_email
@@ -193,39 +196,27 @@ class AuthHandler(BaseHandler):
self.hs = hs # FIXME better possibility to access registrationHandler later?
self.macaroon_gen = hs.get_macaroon_generator()
self._password_enabled = hs.config.password_enabled
- self._sso_enabled = (
- hs.config.cas_enabled or hs.config.saml2_enabled or hs.config.oidc_enabled
- )
-
- # we keep this as a list despite the O(N^2) implication so that we can
- # keep PASSWORD first and avoid confusing clients which pick the first
- # type in the list. (NB that the spec doesn't require us to do so and
- # clients which favour types that they don't understand over those that
- # they do are technically broken)
+ self._password_localdb_enabled = hs.config.password_localdb_enabled
# start out by assuming PASSWORD is enabled; we will remove it later if not.
- login_types = []
- if hs.config.password_localdb_enabled:
- login_types.append(LoginType.PASSWORD)
+ login_types = set()
+ if self._password_localdb_enabled:
+ login_types.add(LoginType.PASSWORD)
for provider in self.password_providers:
- if hasattr(provider, "get_supported_login_types"):
- for t in provider.get_supported_login_types().keys():
- if t not in login_types:
- login_types.append(t)
+ login_types.update(provider.get_supported_login_types().keys())
if not self._password_enabled:
+ login_types.discard(LoginType.PASSWORD)
+
+ # Some clients just pick the first type in the list. In this case, we want
+ # them to use PASSWORD (rather than token or whatever), so we want to make sure
+ # that comes first, where it's present.
+ self._supported_login_types = []
+ if LoginType.PASSWORD in login_types:
+ self._supported_login_types.append(LoginType.PASSWORD)
login_types.remove(LoginType.PASSWORD)
-
- self._supported_login_types = login_types
-
- # Login types and UI Auth types have a heavy overlap, but are not
- # necessarily identical. Login types have SSO (and other login types)
- # added in the rest layer, see synapse.rest.client.v1.login.LoginRestServerlet.on_GET.
- ui_auth_types = login_types.copy()
- if self._sso_enabled:
- ui_auth_types.append(LoginType.SSO)
- self._supported_ui_auth_types = ui_auth_types
+ self._supported_login_types.extend(login_types)
# Ratelimiter for failed auth during UIA. Uses same ratelimit config
# as per `rc_login.failed_attempts`.
@@ -339,7 +330,10 @@ class AuthHandler(BaseHandler):
self._failed_uia_attempts_ratelimiter.ratelimit(user_id, update=False)
# build a list of supported flows
- flows = [[login_type] for login_type in self._supported_ui_auth_types]
+ supported_ui_auth_types = await self._get_available_ui_auth_types(
+ requester.user
+ )
+ flows = [[login_type] for login_type in supported_ui_auth_types]
try:
result, params, session_id = await self.check_ui_auth(
@@ -351,7 +345,7 @@ class AuthHandler(BaseHandler):
raise
# find the completed login type
- for login_type in self._supported_ui_auth_types:
+ for login_type in supported_ui_auth_types:
if login_type not in result:
continue
@@ -367,6 +361,41 @@ class AuthHandler(BaseHandler):
return params, session_id
+ async def _get_available_ui_auth_types(self, user: UserID) -> Iterable[str]:
+ """Get a list of the authentication types this user can use
+ """
+
+ ui_auth_types = set()
+
+ # if the HS supports password auth, and the user has a non-null password, we
+ # support password auth
+ if self._password_localdb_enabled and self._password_enabled:
+ lookupres = await self._find_user_id_and_pwd_hash(user.to_string())
+ if lookupres:
+ _, password_hash = lookupres
+ if password_hash:
+ ui_auth_types.add(LoginType.PASSWORD)
+
+ # also allow auth from password providers
+ for provider in self.password_providers:
+ for t in provider.get_supported_login_types().keys():
+ if t == LoginType.PASSWORD and not self._password_enabled:
+ continue
+ ui_auth_types.add(t)
+
+ # if sso is enabled, allow the user to log in via SSO iff they have a mapping
+ # from sso to mxid.
+ if self.hs.config.saml2.saml2_enabled or self.hs.config.oidc.oidc_enabled:
+ if await self.store.get_external_ids_by_user(user.to_string()):
+ ui_auth_types.add(LoginType.SSO)
+
+ # Our CAS impl does not (yet) correctly register users in user_external_ids,
+ # so always offer that if it's available.
+ if self.hs.config.cas.cas_enabled:
+ ui_auth_types.add(LoginType.SSO)
+
+ return ui_auth_types
+
def get_enabled_auth_types(self):
"""Return the enabled user-interactive authentication types
@@ -831,7 +860,7 @@ class AuthHandler(BaseHandler):
async def validate_login(
self, login_submission: Dict[str, Any], ratelimit: bool = False,
- ) -> Tuple[str, Optional[Callable[[Dict[str, str]], None]]]:
+ ) -> Tuple[str, Optional[Callable[[Dict[str, str]], Awaitable[None]]]]:
"""Authenticates the user for the /login API
Also used by the user-interactive auth flow to validate auth types which don't
@@ -974,7 +1003,7 @@ class AuthHandler(BaseHandler):
async def _validate_userid_login(
self, username: str, login_submission: Dict[str, Any],
- ) -> Tuple[str, Optional[Callable[[Dict[str, str]], None]]]:
+ ) -> Tuple[str, Optional[Callable[[Dict[str, str]], Awaitable[None]]]]:
"""Helper for validate_login
Handles login, once we've mapped 3pids onto userids
@@ -1029,7 +1058,7 @@ class AuthHandler(BaseHandler):
if result:
return result
- if login_type == LoginType.PASSWORD and self.hs.config.password_localdb_enabled:
+ if login_type == LoginType.PASSWORD and self._password_localdb_enabled:
known_login_type = True
# we've already checked that there is a (valid) password field
@@ -1052,7 +1081,7 @@ class AuthHandler(BaseHandler):
async def check_password_provider_3pid(
self, medium: str, address: str, password: str
- ) -> Tuple[Optional[str], Optional[Callable[[Dict[str, str]], None]]]:
+ ) -> Tuple[Optional[str], Optional[Callable[[Dict[str, str]], Awaitable[None]]]]:
"""Check if a password provider is able to validate a thirdparty login
Args:
@@ -1303,15 +1332,14 @@ class AuthHandler(BaseHandler):
)
async def complete_sso_ui_auth(
- self, registered_user_id: str, session_id: str, request: SynapseRequest,
+ self, registered_user_id: str, session_id: str, request: Request,
):
"""Having figured out a mxid for this user, complete the HTTP request
Args:
registered_user_id: The registered user ID to complete SSO login for.
+ session_id: The ID of the user-interactive auth session.
request: The request to complete.
- client_redirect_url: The URL to which to redirect the user at the end of the
- process.
"""
# Mark the stage of the authentication as successful.
# Save the user who authenticated with SSO, this will be used to ensure
@@ -1327,7 +1355,7 @@ class AuthHandler(BaseHandler):
async def complete_sso_login(
self,
registered_user_id: str,
- request: SynapseRequest,
+ request: Request,
client_redirect_url: str,
extra_attributes: Optional[JsonDict] = None,
):
@@ -1355,7 +1383,7 @@ class AuthHandler(BaseHandler):
def _complete_sso_login(
self,
registered_user_id: str,
- request: SynapseRequest,
+ request: Request,
client_redirect_url: str,
extra_attributes: Optional[JsonDict] = None,
):
@@ -1609,6 +1637,6 @@ class PasswordProvider:
# This might return an awaitable, if it does block the log out
# until it completes.
- result = g(user_id=user_id, device_id=device_id, access_token=access_token,)
- if inspect.isawaitable(result):
- await result
+ await maybe_awaitable(
+ g(user_id=user_id, device_id=device_id, access_token=access_token,)
+ )
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index ad5683d251..abcf86352d 100644
--- a/synapse/handlers/directory.py
+++ b/synapse/handlers/directory.py
@@ -133,7 +133,9 @@ class DirectoryHandler(BaseHandler):
403, "You must be in the room to create an alias for it"
)
- if not self.spam_checker.user_may_create_room_alias(user_id, room_alias):
+ if not await self.spam_checker.user_may_create_room_alias(
+ user_id, room_alias
+ ):
raise AuthError(403, "This user is not permitted to create this alias")
if not self.config.is_alias_creation_allowed(
@@ -409,7 +411,7 @@ class DirectoryHandler(BaseHandler):
"""
user_id = requester.user.to_string()
- if not self.spam_checker.user_may_publish_room(user_id, room_id):
+ if not await self.spam_checker.user_may_publish_room(user_id, room_id):
raise AuthError(
403, "This user is not permitted to publish rooms to the room list"
)
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index b9799090f7..fd8de8696d 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -140,7 +140,7 @@ class FederationHandler(BaseHandler):
self._message_handler = hs.get_message_handler()
self._server_notices_mxid = hs.config.server_notices_mxid
self.config = hs.config
- self.http_client = hs.get_simple_http_client()
+ self.http_client = hs.get_proxied_blacklisted_http_client()
self._instance_name = hs.get_instance_name()
self._replication = hs.get_replication_data_handler()
@@ -1593,7 +1593,7 @@ class FederationHandler(BaseHandler):
if self.hs.config.block_non_admin_invites:
raise SynapseError(403, "This server does not accept room invites")
- if not self.spam_checker.user_may_invite(
+ if not await self.spam_checker.user_may_invite(
event.sender, event.state_key, event.room_id
):
raise SynapseError(
diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index 9b3c6b4551..7301c24710 100644
--- a/synapse/handlers/identity.py
+++ b/synapse/handlers/identity.py
@@ -46,13 +46,13 @@ class IdentityHandler(BaseHandler):
def __init__(self, hs):
super().__init__(hs)
+ # An HTTP client for contacting trusted URLs.
self.http_client = SimpleHttpClient(hs)
- # We create a blacklisting instance of SimpleHttpClient for contacting identity
- # servers specified by clients
+ # An HTTP client for contacting identity servers specified by clients.
self.blacklisting_http_client = SimpleHttpClient(
hs, ip_blacklist=hs.config.federation_ip_range_blacklist
)
- self.federation_http_client = hs.get_http_client()
+ self.federation_http_client = hs.get_federation_http_client()
self.hs = hs
async def threepid_from_creds(
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 11420ea996..cbac43c536 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -744,7 +744,7 @@ class EventCreationHandler:
event.sender,
)
- spam_error = self.spam_checker.check_event_for_spam(event)
+ spam_error = await self.spam_checker.check_event_for_spam(event)
if spam_error:
if not isinstance(spam_error, str):
spam_error = "Spam is not permitted here"
diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py
index c605f7082a..f626117f76 100644
--- a/synapse/handlers/oidc_handler.py
+++ b/synapse/handlers/oidc_handler.py
@@ -674,6 +674,21 @@ class OidcHandler(BaseHandler):
self._sso_handler.render_error(request, "invalid_token", str(e))
return
+ # first check if we're doing a UIA
+ if ui_auth_session_id:
+ try:
+ remote_user_id = self._remote_id_from_userinfo(userinfo)
+ except Exception as e:
+ logger.exception("Could not extract remote user id")
+ self._sso_handler.render_error(request, "mapping_error", str(e))
+ return
+
+ return await self._sso_handler.complete_sso_ui_auth_request(
+ self._auth_provider_id, remote_user_id, ui_auth_session_id, request
+ )
+
+ # otherwise, it's a login
+
# Pull out the user-agent and IP from the request.
user_agent = request.get_user_agent("")
ip_address = self.hs.get_ip_from_request(request)
@@ -698,14 +713,9 @@ class OidcHandler(BaseHandler):
extra_attributes = await get_extra_attributes(userinfo, token)
# and finally complete the login
- if ui_auth_session_id:
- await self._auth_handler.complete_sso_ui_auth(
- user_id, ui_auth_session_id, request
- )
- else:
- await self._auth_handler.complete_sso_login(
- user_id, request, client_redirect_url, extra_attributes
- )
+ await self._auth_handler.complete_sso_login(
+ user_id, request, client_redirect_url, extra_attributes
+ )
def _generate_oidc_session_token(
self,
@@ -856,14 +866,11 @@ class OidcHandler(BaseHandler):
The mxid of the user
"""
try:
- remote_user_id = self._user_mapping_provider.get_remote_user_id(userinfo)
+ remote_user_id = self._remote_id_from_userinfo(userinfo)
except Exception as e:
raise MappingException(
"Failed to extract subject from OIDC response: %s" % (e,)
)
- # Some OIDC providers use integer IDs, but Synapse expects external IDs
- # to be strings.
- remote_user_id = str(remote_user_id)
# Older mapping providers don't accept the `failures` argument, so we
# try and detect support.
@@ -933,6 +940,19 @@ class OidcHandler(BaseHandler):
grandfather_existing_users,
)
+ def _remote_id_from_userinfo(self, userinfo: UserInfo) -> str:
+ """Extract the unique remote id from an OIDC UserInfo block
+
+ Args:
+ userinfo: An object representing the user given by the OIDC provider
+ Returns:
+ remote user id
+ """
+ remote_user_id = self._user_mapping_provider.get_remote_user_id(userinfo)
+ # Some OIDC providers use integer IDs, but Synapse expects external IDs
+ # to be strings.
+ return str(remote_user_id)
+
UserAttributeDict = TypedDict(
"UserAttributeDict", {"localpart": str, "display_name": Optional[str]}
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index 153cbae7b9..e850e45e46 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -18,7 +18,6 @@ from typing import List, Tuple
from synapse.appservice import ApplicationService
from synapse.handlers._base import BaseHandler
from synapse.types import JsonDict, ReadReceipt, get_domain_from_id
-from synapse.util.async_helpers import maybe_awaitable
logger = logging.getLogger(__name__)
@@ -98,10 +97,8 @@ class ReceiptsHandler(BaseHandler):
self.notifier.on_new_event("receipt_key", max_batch_id, rooms=affected_room_ids)
# Note that the min here shouldn't be relied upon to be accurate.
- await maybe_awaitable(
- self.hs.get_pusherpool().on_new_receipts(
- min_batch_id, max_batch_id, affected_room_ids
- )
+ await self.hs.get_pusherpool().on_new_receipts(
+ min_batch_id, max_batch_id, affected_room_ids
)
return True
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 0d85fd0868..94b5610acd 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -187,7 +187,7 @@ class RegistrationHandler(BaseHandler):
"""
self.check_registration_ratelimit(address)
- result = self.spam_checker.check_registration_for_spam(
+ result = await self.spam_checker.check_registration_for_spam(
threepid, localpart, user_agent_ips or [],
)
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 930047e730..7583418946 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -358,7 +358,7 @@ class RoomCreationHandler(BaseHandler):
"""
user_id = requester.user.to_string()
- if not self.spam_checker.user_may_create_room(user_id):
+ if not await self.spam_checker.user_may_create_room(user_id):
raise SynapseError(403, "You are not permitted to create rooms")
creation_content = {
@@ -440,6 +440,7 @@ class RoomCreationHandler(BaseHandler):
invite_list=[],
initial_state=initial_state,
creation_content=creation_content,
+ ratelimit=False,
)
# Transfer membership events
@@ -608,7 +609,7 @@ class RoomCreationHandler(BaseHandler):
403, "You are not permitted to create rooms", Codes.FORBIDDEN
)
- if not is_requester_admin and not self.spam_checker.user_may_create_room(
+ if not is_requester_admin and not await self.spam_checker.user_may_create_room(
user_id
):
raise SynapseError(403, "You are not permitted to create rooms")
@@ -735,6 +736,7 @@ class RoomCreationHandler(BaseHandler):
room_alias=room_alias,
power_level_content_override=power_level_content_override,
creator_join_profile=creator_join_profile,
+ ratelimit=ratelimit,
)
if "name" in config:
@@ -838,6 +840,7 @@ class RoomCreationHandler(BaseHandler):
room_alias: Optional[RoomAlias] = None,
power_level_content_override: Optional[JsonDict] = None,
creator_join_profile: Optional[JsonDict] = None,
+ ratelimit: bool = True,
) -> int:
"""Sends the initial events into a new room.
@@ -884,7 +887,7 @@ class RoomCreationHandler(BaseHandler):
creator.user,
room_id,
"join",
- ratelimit=False,
+ ratelimit=ratelimit,
content=creator_join_profile,
)
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 4d8ffe8821..bea028b2bf 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -204,7 +204,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
# Only rate-limit if the user actually joined the room, otherwise we'll end
# up blocking profile updates.
- if newly_joined:
+ if newly_joined and ratelimit:
time_now_s = self.clock.time()
(
allowed,
@@ -428,7 +428,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
)
block_invite = True
- if not self.spam_checker.user_may_invite(
+ if not await self.spam_checker.user_may_invite(
requester.user.to_string(), target.to_string(), room_id
):
logger.info("Blocking invite due to spam checker")
@@ -508,17 +508,20 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
raise AuthError(403, "Guest access not allowed")
if not is_host_in_room:
- time_now_s = self.clock.time()
- (
- allowed,
- time_allowed,
- ) = self._join_rate_limiter_remote.can_requester_do_action(requester,)
-
- if not allowed:
- raise LimitExceededError(
- retry_after_ms=int(1000 * (time_allowed - time_now_s))
+ if ratelimit:
+ time_now_s = self.clock.time()
+ (
+ allowed,
+ time_allowed,
+ ) = self._join_rate_limiter_remote.can_requester_do_action(
+ requester,
)
+ if not allowed:
+ raise LimitExceededError(
+ retry_after_ms=int(1000 * (time_allowed - time_now_s))
+ )
+
inviter = await self._get_inviter(target.to_string(), room_id)
if inviter and not self.hs.is_mine(inviter):
remote_room_hosts.append(inviter.domain)
diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py
index 76d4169fe2..f2ca1ddb53 100644
--- a/synapse/handlers/saml_handler.py
+++ b/synapse/handlers/saml_handler.py
@@ -34,7 +34,6 @@ from synapse.types import (
map_username_to_mxid_localpart,
mxid_localpart_allowed_characters,
)
-from synapse.util.async_helpers import Linearizer
from synapse.util.iterutils import chunk_seq
if TYPE_CHECKING:
@@ -81,9 +80,6 @@ class SamlHandler(BaseHandler):
# a map from saml session id to Saml2SessionData object
self._outstanding_requests_dict = {} # type: Dict[str, Saml2SessionData]
- # a lock on the mappings
- self._mapping_lock = Linearizer(name="saml_mapping", clock=self.clock)
-
self._sso_handler = hs.get_sso_handler()
def handle_redirect_request(
@@ -183,6 +179,24 @@ class SamlHandler(BaseHandler):
saml2_auth.in_response_to, None
)
+ # first check if we're doing a UIA
+ if current_session and current_session.ui_auth_session_id:
+ try:
+ remote_user_id = self._remote_id_from_saml_response(saml2_auth, None)
+ except MappingException as e:
+ logger.exception("Failed to extract remote user id from SAML response")
+ self._sso_handler.render_error(request, "mapping_error", str(e))
+ return
+
+ return await self._sso_handler.complete_sso_ui_auth_request(
+ self._auth_provider_id,
+ remote_user_id,
+ current_session.ui_auth_session_id,
+ request,
+ )
+
+ # otherwise, we're handling a login request.
+
# Ensure that the attributes of the logged in user meet the required
# attributes.
for requirement in self._saml2_attribute_requirements:
@@ -206,14 +220,7 @@ class SamlHandler(BaseHandler):
self._sso_handler.render_error(request, "mapping_error", str(e))
return
- # Complete the interactive auth session or the login.
- if current_session and current_session.ui_auth_session_id:
- await self._auth_handler.complete_sso_ui_auth(
- user_id, current_session.ui_auth_session_id, request
- )
-
- else:
- await self._auth_handler.complete_sso_login(user_id, request, relay_state)
+ await self._auth_handler.complete_sso_login(user_id, request, relay_state)
async def _map_saml_response_to_user(
self,
@@ -239,16 +246,10 @@ class SamlHandler(BaseHandler):
RedirectException: some mapping providers may raise this if they need
to redirect to an interstitial page.
"""
-
- remote_user_id = self._user_mapping_provider.get_remote_user_id(
+ remote_user_id = self._remote_id_from_saml_response(
saml2_auth, client_redirect_url
)
- if not remote_user_id:
- raise MappingException(
- "Failed to extract remote user id from SAML response"
- )
-
async def saml_response_to_remapped_user_attributes(
failures: int,
) -> UserAttributes:
@@ -294,16 +295,44 @@ class SamlHandler(BaseHandler):
return None
- with (await self._mapping_lock.queue(self._auth_provider_id)):
- return await self._sso_handler.get_mxid_from_sso(
- self._auth_provider_id,
- remote_user_id,
- user_agent,
- ip_address,
- saml_response_to_remapped_user_attributes,
- grandfather_existing_users,
+ return await self._sso_handler.get_mxid_from_sso(
+ self._auth_provider_id,
+ remote_user_id,
+ user_agent,
+ ip_address,
+ saml_response_to_remapped_user_attributes,
+ grandfather_existing_users,
+ )
+
+ def _remote_id_from_saml_response(
+ self,
+ saml2_auth: saml2.response.AuthnResponse,
+ client_redirect_url: Optional[str],
+ ) -> str:
+ """Extract the unique remote id from a SAML2 AuthnResponse
+
+ Args:
+ saml2_auth: The parsed SAML2 response.
+ client_redirect_url: The redirect URL passed in by the client.
+ Returns:
+ remote user id
+
+ Raises:
+ MappingException if there was an error extracting the user id
+ """
+ # It's not obvious why we need to pass in the redirect URI to the mapping
+ # provider, but we do :/
+ remote_user_id = self._user_mapping_provider.get_remote_user_id(
+ saml2_auth, client_redirect_url
+ )
+
+ if not remote_user_id:
+ raise MappingException(
+ "Failed to extract remote user id from SAML response"
)
+ return remote_user_id
+
def expire_sessions(self):
expire_before = self.clock.time_msec() - self._saml2_session_lifetime
to_expire = set()
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index 47ad96f97e..112a7d5b2c 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -17,10 +17,12 @@ from typing import TYPE_CHECKING, Awaitable, Callable, List, Optional
import attr
+from twisted.web.http import Request
+
from synapse.api.errors import RedirectException
-from synapse.handlers._base import BaseHandler
from synapse.http.server import respond_with_html
from synapse.types import UserID, contains_invalid_mxid_characters
+from synapse.util.async_helpers import Linearizer
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -42,14 +44,19 @@ class UserAttributes:
emails = attr.ib(type=List[str], default=attr.Factory(list))
-class SsoHandler(BaseHandler):
+class SsoHandler:
# The number of attempts to ask the mapping provider for when generating an MXID.
_MAP_USERNAME_RETRIES = 1000
def __init__(self, hs: "HomeServer"):
- super().__init__(hs)
+ self._store = hs.get_datastore()
+ self._server_name = hs.hostname
self._registration_handler = hs.get_registration_handler()
self._error_template = hs.config.sso_error_template
+ self._auth_handler = hs.get_auth_handler()
+
+ # a lock on the mappings
+ self._mapping_lock = Linearizer(name="sso_user_mapping", clock=hs.get_clock())
def render_error(
self, request, error: str, error_description: Optional[str] = None
@@ -95,7 +102,7 @@ class SsoHandler(BaseHandler):
)
# Check if we already have a mapping for this user.
- previously_registered_user_id = await self.store.get_user_by_external_id(
+ previously_registered_user_id = await self._store.get_user_by_external_id(
auth_provider_id, remote_user_id,
)
@@ -169,24 +176,38 @@ class SsoHandler(BaseHandler):
to an additional page. (e.g. to prompt for more information)
"""
- # first of all, check if we already have a mapping for this user
- previously_registered_user_id = await self.get_sso_user_by_remote_user_id(
- auth_provider_id, remote_user_id,
- )
- if previously_registered_user_id:
- return previously_registered_user_id
-
- # Check for grandfathering of users.
- if grandfather_existing_users:
- previously_registered_user_id = await grandfather_existing_users()
+ # grab a lock while we try to find a mapping for this user. This seems...
+ # optimistic, especially for implementations that end up redirecting to
+ # interstitial pages.
+ with await self._mapping_lock.queue(auth_provider_id):
+ # first of all, check if we already have a mapping for this user
+ previously_registered_user_id = await self.get_sso_user_by_remote_user_id(
+ auth_provider_id, remote_user_id,
+ )
if previously_registered_user_id:
- # Future logins should also match this user ID.
- await self.store.record_user_external_id(
- auth_provider_id, remote_user_id, previously_registered_user_id
- )
return previously_registered_user_id
- # Otherwise, generate a new user.
+ # Check for grandfathering of users.
+ if grandfather_existing_users:
+ previously_registered_user_id = await grandfather_existing_users()
+ if previously_registered_user_id:
+ # Future logins should also match this user ID.
+ await self._store.record_user_external_id(
+ auth_provider_id, remote_user_id, previously_registered_user_id
+ )
+ return previously_registered_user_id
+
+ # Otherwise, generate a new user.
+ attributes = await self._call_attribute_mapper(sso_to_matrix_id_mapper)
+ user_id = await self._register_mapped_user(
+ attributes, auth_provider_id, remote_user_id, user_agent, ip_address,
+ )
+ return user_id
+
+ async def _call_attribute_mapper(
+ self, sso_to_matrix_id_mapper: Callable[[int], Awaitable[UserAttributes]],
+ ) -> UserAttributes:
+ """Call the attribute mapper function in a loop, until we get a unique userid"""
for i in range(self._MAP_USERNAME_RETRIES):
try:
attributes = await sso_to_matrix_id_mapper(i)
@@ -214,8 +235,8 @@ class SsoHandler(BaseHandler):
)
# Check if this mxid already exists
- user_id = UserID(attributes.localpart, self.server_name).to_string()
- if not await self.store.get_users_by_id_case_insensitive(user_id):
+ user_id = UserID(attributes.localpart, self._server_name).to_string()
+ if not await self._store.get_users_by_id_case_insensitive(user_id):
# This mxid is free
break
else:
@@ -224,7 +245,16 @@ class SsoHandler(BaseHandler):
raise MappingException(
"Unable to generate a Matrix ID from the SSO response"
)
+ return attributes
+ async def _register_mapped_user(
+ self,
+ attributes: UserAttributes,
+ auth_provider_id: str,
+ remote_user_id: str,
+ user_agent: str,
+ ip_address: str,
+ ) -> str:
# Since the localpart is provided via a potentially untrusted module,
# ensure the MXID is valid before registering.
if contains_invalid_mxid_characters(attributes.localpart):
@@ -238,7 +268,47 @@ class SsoHandler(BaseHandler):
user_agent_ips=[(user_agent, ip_address)],
)
- await self.store.record_user_external_id(
+ await self._store.record_user_external_id(
auth_provider_id, remote_user_id, registered_user_id
)
return registered_user_id
+
+ async def complete_sso_ui_auth_request(
+ self,
+ auth_provider_id: str,
+ remote_user_id: str,
+ ui_auth_session_id: str,
+ request: Request,
+ ) -> None:
+ """
+ Given an SSO ID, retrieve the user ID for it and complete UIA.
+
+ Note that this requires that the user is mapped in the "user_external_ids"
+ table. This will be the case if they have ever logged in via SAML or OIDC in
+ recentish synapse versions, but may not be for older users.
+
+ Args:
+ auth_provider_id: A unique identifier for this SSO provider, e.g.
+ "oidc" or "saml".
+ remote_user_id: The unique identifier from the SSO provider.
+ ui_auth_session_id: The ID of the user-interactive auth session.
+ request: The request to complete.
+ """
+
+ user_id = await self.get_sso_user_by_remote_user_id(
+ auth_provider_id, remote_user_id,
+ )
+
+ if not user_id:
+ logger.warning(
+ "Remote user %s/%s has not previously logged in here: UIA will fail",
+ auth_provider_id,
+ remote_user_id,
+ )
+ # Let the UIA flow handle this the same as if they presented creds for a
+ # different user.
+ user_id = ""
+
+ await self._auth_handler.complete_sso_ui_auth(
+ user_id, ui_auth_session_id, request
+ )
diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py
index afbebfc200..f263a638f8 100644
--- a/synapse/handlers/user_directory.py
+++ b/synapse/handlers/user_directory.py
@@ -81,11 +81,11 @@ class UserDirectoryHandler(StateDeltasHandler):
results = await self.store.search_user_dir(user_id, search_term, limit)
# Remove any spammy users from the results.
- results["results"] = [
- user
- for user in results["results"]
- if not self.spam_checker.check_username_for_spam(user)
- ]
+ non_spammy_users = []
+ for user in results["results"]:
+ if not await self.spam_checker.check_username_for_spam(user):
+ non_spammy_users.append(user)
+ results["results"] = non_spammy_users
return results
|