diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 00eae92052..c7dc07008a 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2014 - 2016 OpenMarket Ltd
# Copyright 2017 Vector Creations Ltd
+# Copyright 2019 - 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,10 +19,21 @@ import logging
import time
import unicodedata
import urllib.parse
-from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Callable,
+ Dict,
+ Iterable,
+ List,
+ Mapping,
+ Optional,
+ Tuple,
+ Union,
+)
import attr
-import bcrypt # type: ignore[import]
+import bcrypt
import pymacaroons
from synapse.api.constants import LoginType
@@ -49,6 +61,9 @@ from synapse.util.threepids import canonicalise_email
from ._base import BaseHandler
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
+
logger = logging.getLogger(__name__)
@@ -149,11 +164,7 @@ class SsoLoginExtraAttributes:
class AuthHandler(BaseHandler):
SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000
- def __init__(self, hs):
- """
- Args:
- hs (synapse.server.HomeServer):
- """
+ def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.checkers = {} # type: Dict[str, UserInteractiveAuthChecker]
@@ -164,13 +175,20 @@ class AuthHandler(BaseHandler):
self.bcrypt_rounds = hs.config.bcrypt_rounds
+ # we can't use hs.get_module_api() here, because to do so will create an
+ # import loop.
+ #
+ # TODO: refactor this class to separate the lower-level stuff that
+ # ModuleApi can use from the higher-level stuff that uses ModuleApi, as
+ # better way to break the loop
account_handler = ModuleApi(hs, self)
+
self.password_providers = [
- module(config=config, account_handler=account_handler)
+ PasswordProvider.load(module, config, account_handler)
for module, config in hs.config.password_providers
]
- logger.info("Extra password_providers: %r", self.password_providers)
+ logger.info("Extra password_providers: %s", self.password_providers)
self.hs = hs # FIXME better possibility to access registrationHandler later?
self.macaroon_gen = hs.get_macaroon_generator()
@@ -184,15 +202,23 @@ class AuthHandler(BaseHandler):
# type in the list. (NB that the spec doesn't require us to do so and
# clients which favour types that they don't understand over those that
# they do are technically broken)
+
+ # start out by assuming PASSWORD is enabled; we will remove it later if not.
login_types = []
- if self._password_enabled:
+ if hs.config.password_localdb_enabled:
login_types.append(LoginType.PASSWORD)
+
for provider in self.password_providers:
if hasattr(provider, "get_supported_login_types"):
for t in provider.get_supported_login_types().keys():
if t not in login_types:
login_types.append(t)
+
+ if not self._password_enabled:
+ login_types.remove(LoginType.PASSWORD)
+
self._supported_login_types = login_types
+
# Login types and UI Auth types have a heavy overlap, but are not
# necessarily identical. Login types have SSO (and other login types)
# added in the rest layer, see synapse.rest.client.v1.login.LoginRestServerlet.on_GET.
@@ -209,10 +235,17 @@ class AuthHandler(BaseHandler):
burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
)
+ # Ratelimitier for failed /login attempts
+ self._failed_login_attempts_ratelimiter = Ratelimiter(
+ clock=hs.get_clock(),
+ rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
+ burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
+ )
+
self._clock = self.hs.get_clock()
# Expire old UI auth sessions after a period of time.
- if hs.config.worker_app is None:
+ if hs.config.run_background_tasks:
self._clock.looping_call(
run_as_background_process,
5 * 60 * 1000,
@@ -463,9 +496,7 @@ class AuthHandler(BaseHandler):
# authentication flow.
await self.store.set_ui_auth_clientdict(sid, clientdict)
- user_agent = request.requestHeaders.getRawHeaders(b"User-Agent", default=[b""])[
- 0
- ].decode("ascii", "surrogateescape")
+ user_agent = request.get_user_agent("")
await self.store.add_user_agent_ip_to_ui_auth_session(
session.session_id, user_agent, clientip
@@ -623,14 +654,8 @@ class AuthHandler(BaseHandler):
res = await checker.check_auth(authdict, clientip=clientip)
return res
- # build a v1-login-style dict out of the authdict and fall back to the
- # v1 code
- user_id = authdict.get("user")
-
- if user_id is None:
- raise SynapseError(400, "", Codes.MISSING_PARAM)
-
- (canonical_id, callback) = await self.validate_login(user_id, authdict)
+ # fall back to the v1 login flow
+ canonical_id, _ = await self.validate_login(authdict)
return canonical_id
def _get_params_recaptcha(self) -> dict:
@@ -679,13 +704,17 @@ class AuthHandler(BaseHandler):
}
async def get_access_token_for_user_id(
- self, user_id: str, device_id: Optional[str], valid_until_ms: Optional[int]
- ):
+ self,
+ user_id: str,
+ device_id: Optional[str],
+ valid_until_ms: Optional[int],
+ puppets_user_id: Optional[str] = None,
+ ) -> str:
"""
Creates a new access token for the user with the given user ID.
The user is assumed to have been authenticated by some other
- machanism (e.g. CAS), and the user_id converted to the canonical case.
+ mechanism (e.g. CAS), and the user_id converted to the canonical case.
The device will be recorded in the table if it is not there already.
@@ -706,13 +735,25 @@ class AuthHandler(BaseHandler):
fmt_expiry = time.strftime(
" until %Y-%m-%d %H:%M:%S", time.localtime(valid_until_ms / 1000.0)
)
- logger.info("Logging in user %s on device %s%s", user_id, device_id, fmt_expiry)
+
+ if puppets_user_id:
+ logger.info(
+ "Logging in user %s as %s%s", user_id, puppets_user_id, fmt_expiry
+ )
+ else:
+ logger.info(
+ "Logging in user %s on device %s%s", user_id, device_id, fmt_expiry
+ )
await self.auth.check_auth_blocking(user_id)
access_token = self.macaroon_gen.generate_access_token(user_id)
await self.store.add_access_token_to_user(
- user_id, access_token, device_id, valid_until_ms
+ user_id=user_id,
+ token=access_token,
+ device_id=device_id,
+ valid_until_ms=valid_until_ms,
+ puppets_user_id=puppets_user_id,
)
# the device *should* have been registered before we got here; however,
@@ -789,17 +830,17 @@ class AuthHandler(BaseHandler):
return self._supported_login_types
async def validate_login(
- self, username: str, login_submission: Dict[str, Any]
+ self, login_submission: Dict[str, Any], ratelimit: bool = False,
) -> Tuple[str, Optional[Callable[[Dict[str, str]], None]]]:
"""Authenticates the user for the /login API
- Also used by the user-interactive auth flow to validate
- m.login.password auth types.
+ Also used by the user-interactive auth flow to validate auth types which don't
+ have an explicit UIA handler, including m.password.auth.
Args:
- username: username supplied by the user
login_submission: the whole of the login submission
(including 'type' and other relevant fields)
+ ratelimit: whether to apply the failed_login_attempt ratelimiter
Returns:
A tuple of the canonical user id, and optional callback
to be called once the access token and device id are issued
@@ -808,38 +849,160 @@ class AuthHandler(BaseHandler):
SynapseError if there was a problem with the request
LoginError if there was an authentication problem.
"""
-
- if username.startswith("@"):
- qualified_user_id = username
- else:
- qualified_user_id = UserID(username, self.hs.hostname).to_string()
-
login_type = login_submission.get("type")
- known_login_type = False
+ if not isinstance(login_type, str):
+ raise SynapseError(400, "Bad parameter: type", Codes.INVALID_PARAM)
+
+ # ideally, we wouldn't be checking the identifier unless we know we have a login
+ # method which uses it (https://github.com/matrix-org/synapse/issues/8836)
+ #
+ # But the auth providers' check_auth interface requires a username, so in
+ # practice we can only support login methods which we can map to a username
+ # anyway.
# special case to check for "password" for the check_password interface
# for the auth providers
password = login_submission.get("password")
-
if login_type == LoginType.PASSWORD:
if not self._password_enabled:
raise SynapseError(400, "Password login has been disabled.")
- if not password:
- raise SynapseError(400, "Missing parameter: password")
+ if not isinstance(password, str):
+ raise SynapseError(400, "Bad parameter: password", Codes.INVALID_PARAM)
- for provider in self.password_providers:
- if hasattr(provider, "check_password") and login_type == LoginType.PASSWORD:
- known_login_type = True
- is_valid = await provider.check_password(qualified_user_id, password)
- if is_valid:
- return qualified_user_id, None
+ # map old-school login fields into new-school "identifier" fields.
+ identifier_dict = convert_client_dict_legacy_fields_to_identifier(
+ login_submission
+ )
- if not hasattr(provider, "get_supported_login_types") or not hasattr(
- provider, "check_auth"
- ):
- # this password provider doesn't understand custom login types
- continue
+ # convert phone type identifiers to generic threepids
+ if identifier_dict["type"] == "m.id.phone":
+ identifier_dict = login_id_phone_to_thirdparty(identifier_dict)
+
+ # convert threepid identifiers to user IDs
+ if identifier_dict["type"] == "m.id.thirdparty":
+ address = identifier_dict.get("address")
+ medium = identifier_dict.get("medium")
+
+ if medium is None or address is None:
+ raise SynapseError(400, "Invalid thirdparty identifier")
+
+ # For emails, canonicalise the address.
+ # We store all email addresses canonicalised in the DB.
+ # (See add_threepid in synapse/handlers/auth.py)
+ if medium == "email":
+ try:
+ address = canonicalise_email(address)
+ except ValueError as e:
+ raise SynapseError(400, str(e))
+
+ # We also apply account rate limiting using the 3PID as a key, as
+ # otherwise using 3PID bypasses the ratelimiting based on user ID.
+ if ratelimit:
+ self._failed_login_attempts_ratelimiter.ratelimit(
+ (medium, address), update=False
+ )
+
+ # Check for login providers that support 3pid login types
+ if login_type == LoginType.PASSWORD:
+ # we've already checked that there is a (valid) password field
+ assert isinstance(password, str)
+ (
+ canonical_user_id,
+ callback_3pid,
+ ) = await self.check_password_provider_3pid(medium, address, password)
+ if canonical_user_id:
+ # Authentication through password provider and 3pid succeeded
+ return canonical_user_id, callback_3pid
+
+ # No password providers were able to handle this 3pid
+ # Check local store
+ user_id = await self.hs.get_datastore().get_user_id_by_threepid(
+ medium, address
+ )
+ if not user_id:
+ logger.warning(
+ "unknown 3pid identifier medium %s, address %r", medium, address
+ )
+ # We mark that we've failed to log in here, as
+ # `check_password_provider_3pid` might have returned `None` due
+ # to an incorrect password, rather than the account not
+ # existing.
+ #
+ # If it returned None but the 3PID was bound then we won't hit
+ # this code path, which is fine as then the per-user ratelimit
+ # will kick in below.
+ if ratelimit:
+ self._failed_login_attempts_ratelimiter.can_do_action(
+ (medium, address)
+ )
+ raise LoginError(403, "", errcode=Codes.FORBIDDEN)
+
+ identifier_dict = {"type": "m.id.user", "user": user_id}
+
+ # by this point, the identifier should be an m.id.user: if it's anything
+ # else, we haven't understood it.
+ if identifier_dict["type"] != "m.id.user":
+ raise SynapseError(400, "Unknown login identifier type")
+
+ username = identifier_dict.get("user")
+ if not username:
+ raise SynapseError(400, "User identifier is missing 'user' key")
+
+ if username.startswith("@"):
+ qualified_user_id = username
+ else:
+ qualified_user_id = UserID(username, self.hs.hostname).to_string()
+
+ # Check if we've hit the failed ratelimit (but don't update it)
+ if ratelimit:
+ self._failed_login_attempts_ratelimiter.ratelimit(
+ qualified_user_id.lower(), update=False
+ )
+ try:
+ return await self._validate_userid_login(username, login_submission)
+ except LoginError:
+ # The user has failed to log in, so we need to update the rate
+ # limiter. Using `can_do_action` avoids us raising a ratelimit
+ # exception and masking the LoginError. The actual ratelimiting
+ # should have happened above.
+ if ratelimit:
+ self._failed_login_attempts_ratelimiter.can_do_action(
+ qualified_user_id.lower()
+ )
+ raise
+
+ async def _validate_userid_login(
+ self, username: str, login_submission: Dict[str, Any],
+ ) -> Tuple[str, Optional[Callable[[Dict[str, str]], None]]]:
+ """Helper for validate_login
+
+ Handles login, once we've mapped 3pids onto userids
+
+ Args:
+ username: the username, from the identifier dict
+ login_submission: the whole of the login submission
+ (including 'type' and other relevant fields)
+ Returns:
+ A tuple of the canonical user id, and optional callback
+ to be called once the access token and device id are issued
+ Raises:
+ StoreError if there was a problem accessing the database
+ SynapseError if there was a problem with the request
+ LoginError if there was an authentication problem.
+ """
+ if username.startswith("@"):
+ qualified_user_id = username
+ else:
+ qualified_user_id = UserID(username, self.hs.hostname).to_string()
+
+ login_type = login_submission.get("type")
+ # we already checked that we have a valid login type
+ assert isinstance(login_type, str)
+
+ known_login_type = False
+
+ for provider in self.password_providers:
supported_login_types = provider.get_supported_login_types()
if login_type not in supported_login_types:
# this password provider doesn't understand this login type
@@ -864,15 +1027,17 @@ class AuthHandler(BaseHandler):
result = await provider.check_auth(username, login_type, login_dict)
if result:
- if isinstance(result, str):
- result = (result, None)
return result
if login_type == LoginType.PASSWORD and self.hs.config.password_localdb_enabled:
known_login_type = True
+ # we've already checked that there is a (valid) password field
+ password = login_submission["password"]
+ assert isinstance(password, str)
+
canonical_user_id = await self._check_local_password(
- qualified_user_id, password # type: ignore
+ qualified_user_id, password
)
if canonical_user_id:
@@ -903,19 +1068,9 @@ class AuthHandler(BaseHandler):
unsuccessful, `user_id` and `callback` are both `None`.
"""
for provider in self.password_providers:
- if hasattr(provider, "check_3pid_auth"):
- # This function is able to return a deferred that either
- # resolves None, meaning authentication failure, or upon
- # success, to a str (which is the user_id) or a tuple of
- # (user_id, callback_func), where callback_func should be run
- # after we've finished everything else
- result = await provider.check_3pid_auth(medium, address, password)
- if result:
- # Check if the return value is a str or a tuple
- if isinstance(result, str):
- # If it's a str, set callback function to None
- result = (result, None)
- return result
+ result = await provider.check_3pid_auth(medium, address, password)
+ if result:
+ return result
return None, None
@@ -973,21 +1128,16 @@ class AuthHandler(BaseHandler):
# see if any of our auth providers want to know about this
for provider in self.password_providers:
- if hasattr(provider, "on_logged_out"):
- # This might return an awaitable, if it does block the log out
- # until it completes.
- result = provider.on_logged_out(
- user_id=str(user_info["user"]),
- device_id=user_info["device_id"],
- access_token=access_token,
- )
- if inspect.isawaitable(result):
- await result
+ await provider.on_logged_out(
+ user_id=user_info.user_id,
+ device_id=user_info.device_id,
+ access_token=access_token,
+ )
# delete pushers associated with this access token
- if user_info["token_id"] is not None:
+ if user_info.token_id is not None:
await self.hs.get_pusherpool().remove_pushers_by_access_token(
- str(user_info["user"]), (user_info["token_id"],)
+ user_info.user_id, (user_info.token_id,)
)
async def delete_access_tokens_for_user(
@@ -1011,11 +1161,10 @@ class AuthHandler(BaseHandler):
# see if any of our auth providers want to know about this
for provider in self.password_providers:
- if hasattr(provider, "on_logged_out"):
- for token, token_id, device_id in tokens_and_devices:
- await provider.on_logged_out(
- user_id=user_id, device_id=device_id, access_token=token
- )
+ for token, token_id, device_id in tokens_and_devices:
+ await provider.on_logged_out(
+ user_id=user_id, device_id=device_id, access_token=token
+ )
# delete pushers associated with the access tokens
await self.hs.get_pusherpool().remove_pushers_by_access_token(
@@ -1073,7 +1222,7 @@ class AuthHandler(BaseHandler):
if medium == "email":
address = canonicalise_email(address)
- identity_handler = self.hs.get_handlers().identity_handler
+ identity_handler = self.hs.get_identity_handler()
result = await identity_handler.try_unbind_threepid(
user_id, {"medium": medium, "address": address, "id_server": id_server}
)
@@ -1115,20 +1264,22 @@ class AuthHandler(BaseHandler):
Whether self.hash(password) == stored_hash.
"""
- def _do_validate_hash():
+ def _do_validate_hash(checked_hash: bytes):
# Normalise the Unicode in the password
pw = unicodedata.normalize("NFKC", password)
return bcrypt.checkpw(
pw.encode("utf8") + self.hs.config.password_pepper.encode("utf8"),
- stored_hash,
+ checked_hash,
)
if stored_hash:
if not isinstance(stored_hash, bytes):
stored_hash = stored_hash.encode("ascii")
- return await defer_to_thread(self.hs.get_reactor(), _do_validate_hash)
+ return await defer_to_thread(
+ self.hs.get_reactor(), _do_validate_hash, stored_hash
+ )
else:
return False
@@ -1337,3 +1488,127 @@ class MacaroonGenerator:
macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
return macaroon
+
+
+class PasswordProvider:
+ """Wrapper for a password auth provider module
+
+ This class abstracts out all of the backwards-compatibility hacks for
+ password providers, to provide a consistent interface.
+ """
+
+ @classmethod
+ def load(cls, module, config, module_api: ModuleApi) -> "PasswordProvider":
+ try:
+ pp = module(config=config, account_handler=module_api)
+ except Exception as e:
+ logger.error("Error while initializing %r: %s", module, e)
+ raise
+ return cls(pp, module_api)
+
+ def __init__(self, pp, module_api: ModuleApi):
+ self._pp = pp
+ self._module_api = module_api
+
+ self._supported_login_types = {}
+
+ # grandfather in check_password support
+ if hasattr(self._pp, "check_password"):
+ self._supported_login_types[LoginType.PASSWORD] = ("password",)
+
+ g = getattr(self._pp, "get_supported_login_types", None)
+ if g:
+ self._supported_login_types.update(g())
+
+ def __str__(self):
+ return str(self._pp)
+
+ def get_supported_login_types(self) -> Mapping[str, Iterable[str]]:
+ """Get the login types supported by this password provider
+
+ Returns a map from a login type identifier (such as m.login.password) to an
+ iterable giving the fields which must be provided by the user in the submission
+ to the /login API.
+
+ This wrapper adds m.login.password to the list if the underlying password
+ provider supports the check_password() api.
+ """
+ return self._supported_login_types
+
+ async def check_auth(
+ self, username: str, login_type: str, login_dict: JsonDict
+ ) -> Optional[Tuple[str, Optional[Callable]]]:
+ """Check if the user has presented valid login credentials
+
+ This wrapper also calls check_password() if the underlying password provider
+ supports the check_password() api and the login type is m.login.password.
+
+ Args:
+ username: user id presented by the client. Either an MXID or an unqualified
+ username.
+
+ login_type: the login type being attempted - one of the types returned by
+ get_supported_login_types()
+
+ login_dict: the dictionary of login secrets passed by the client.
+
+ Returns: (user_id, callback) where `user_id` is the fully-qualified mxid of the
+ user, and `callback` is an optional callback which will be called with the
+ result from the /login call (including access_token, device_id, etc.)
+ """
+ # first grandfather in a call to check_password
+ if login_type == LoginType.PASSWORD:
+ g = getattr(self._pp, "check_password", None)
+ if g:
+ qualified_user_id = self._module_api.get_qualified_user_id(username)
+ is_valid = await self._pp.check_password(
+ qualified_user_id, login_dict["password"]
+ )
+ if is_valid:
+ return qualified_user_id, None
+
+ g = getattr(self._pp, "check_auth", None)
+ if not g:
+ return None
+ result = await g(username, login_type, login_dict)
+
+ # Check if the return value is a str or a tuple
+ if isinstance(result, str):
+ # If it's a str, set callback function to None
+ return result, None
+
+ return result
+
+ async def check_3pid_auth(
+ self, medium: str, address: str, password: str
+ ) -> Optional[Tuple[str, Optional[Callable]]]:
+ g = getattr(self._pp, "check_3pid_auth", None)
+ if not g:
+ return None
+
+ # This function is able to return a deferred that either
+ # resolves None, meaning authentication failure, or upon
+ # success, to a str (which is the user_id) or a tuple of
+ # (user_id, callback_func), where callback_func should be run
+ # after we've finished everything else
+ result = await g(medium, address, password)
+
+ # Check if the return value is a str or a tuple
+ if isinstance(result, str):
+ # If it's a str, set callback function to None
+ return result, None
+
+ return result
+
+ async def on_logged_out(
+ self, user_id: str, device_id: Optional[str], access_token: str
+ ) -> None:
+ g = getattr(self._pp, "on_logged_out", None)
+ if not g:
+ return
+
+ # This might return an awaitable, if it does block the log out
+ # until it completes.
+ result = g(user_id=user_id, device_id=device_id, access_token=access_token,)
+ if inspect.isawaitable(result):
+ await result
|