diff options
Diffstat (limited to 'synapse/handlers/auth.py')
-rw-r--r-- | synapse/handlers/auth.py | 273 |
1 files changed, 195 insertions, 78 deletions
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 9cef9d184b..7a0ba6ef35 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -13,13 +13,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - from twisted.internet import defer from ._base import BaseHandler from synapse.api.constants import LoginType -from synapse.types import UserID from synapse.api.errors import AuthError, LoginError, Codes, StoreError, SynapseError +from synapse.module_api import ModuleApi +from synapse.types import UserID from synapse.util.async import run_on_reactor from synapse.util.caches.expiringcache import ExpiringCache @@ -63,10 +63,7 @@ class AuthHandler(BaseHandler): reset_expiry_on_get=True, ) - account_handler = _AccountHandler( - hs, check_user_exists=self.check_user_exists - ) - + account_handler = ModuleApi(hs, self) self.password_providers = [ module(config=config, account_handler=account_handler) for module, config in hs.config.password_providers @@ -75,14 +72,24 @@ class AuthHandler(BaseHandler): logger.info("Extra password_providers: %r", self.password_providers) self.hs = hs # FIXME better possibility to access registrationHandler later? - self.device_handler = hs.get_device_handler() self.macaroon_gen = hs.get_macaroon_generator() + self._password_enabled = hs.config.password_enabled + + login_types = set() + if self._password_enabled: + login_types.add(LoginType.PASSWORD) + for provider in self.password_providers: + if hasattr(provider, "get_supported_login_types"): + login_types.update( + provider.get_supported_login_types().keys() + ) + self._supported_login_types = frozenset(login_types) @defer.inlineCallbacks def check_auth(self, flows, clientdict, clientip): """ Takes a dictionary sent by the client in the login / registration - protocol and handles the login flow. + protocol and handles the User-Interactive Auth flow. As a side effect, this function fills in the 'creds' key on the user's session with a map, which maps each auth-type (str) to the relevant @@ -260,16 +267,19 @@ class AuthHandler(BaseHandler): sess = self._get_session_info(session_id) return sess.setdefault('serverdict', {}).get(key, default) + @defer.inlineCallbacks def _check_password_auth(self, authdict, _): if "user" not in authdict or "password" not in authdict: raise LoginError(400, "", Codes.MISSING_PARAM) user_id = authdict["user"] password = authdict["password"] - if not user_id.startswith('@'): - user_id = UserID(user_id, self.hs.hostname).to_string() - return self._check_password(user_id, password) + (canonical_id, callback) = yield self.validate_login(user_id, { + "type": LoginType.PASSWORD, + "password": password, + }) + defer.returnValue(canonical_id) @defer.inlineCallbacks def _check_recaptcha(self, authdict, clientip): @@ -398,26 +408,8 @@ class AuthHandler(BaseHandler): return self.sessions[session_id] - def validate_password_login(self, user_id, password): - """ - Authenticates the user with their username and password. - - Used only by the v1 login API. - - Args: - user_id (str): complete @user:id - password (str): Password - Returns: - defer.Deferred: (str) canonical user id - Raises: - StoreError if there was a problem accessing the database - LoginError if there was an authentication problem. - """ - return self._check_password(user_id, password) - @defer.inlineCallbacks - def get_access_token_for_user_id(self, user_id, device_id=None, - initial_display_name=None): + def get_access_token_for_user_id(self, user_id, device_id=None): """ Creates a new access token for the user with the given user ID. @@ -431,13 +423,10 @@ class AuthHandler(BaseHandler): device_id (str|None): the device ID to associate with the tokens. None to leave the tokens unassociated with a device (deprecated: we should always have a device ID) - initial_display_name (str): display name to associate with the - device if it needs re-registering Returns: The access token for the user's session. Raises: StoreError if there was a problem storing the token. - LoginError if there was an authentication problem. """ logger.info("Logging in user %s on device %s", user_id, device_id) access_token = yield self.issue_access_token(user_id, device_id) @@ -447,9 +436,11 @@ class AuthHandler(BaseHandler): # really don't want is active access_tokens without a record of the # device, so we double-check it here. if device_id is not None: - yield self.device_handler.check_device_registered( - user_id, device_id, initial_display_name - ) + try: + yield self.store.get_device(user_id, device_id) + except StoreError: + yield self.store.delete_access_token(access_token) + raise StoreError(400, "Login raced against device deletion") defer.returnValue(access_token) @@ -501,29 +492,115 @@ class AuthHandler(BaseHandler): ) defer.returnValue(result) + def get_supported_login_types(self): + """Get a the login types supported for the /login API + + By default this is just 'm.login.password' (unless password_enabled is + False in the config file), but password auth providers can provide + other login types. + + Returns: + Iterable[str]: login types + """ + return self._supported_login_types + @defer.inlineCallbacks - def _check_password(self, user_id, password): - """Authenticate a user against the LDAP and local databases. + def validate_login(self, username, login_submission): + """Authenticates the user for the /login API - user_id is checked case insensitively against the local database, but - will throw if there are multiple inexact matches. + Also used by the user-interactive auth flow to validate + m.login.password auth types. Args: - user_id (str): complete @user:id + username (str): username supplied by the user + login_submission (dict): the whole of the login submission + (including 'type' and other relevant fields) Returns: - (str) the canonical_user_id + Deferred[str, func]: canonical user id, and optional callback + to be called once the access token and device id are issued Raises: - LoginError if login fails + StoreError if there was a problem accessing the database + SynapseError if there was a problem with the request + LoginError if there was an authentication problem. """ + + if username.startswith('@'): + qualified_user_id = username + else: + qualified_user_id = UserID( + username, self.hs.hostname + ).to_string() + + login_type = login_submission.get("type") + known_login_type = False + + # special case to check for "password" for the check_password interface + # for the auth providers + password = login_submission.get("password") + if login_type == LoginType.PASSWORD: + if not self._password_enabled: + raise SynapseError(400, "Password login has been disabled.") + if not password: + raise SynapseError(400, "Missing parameter: password") + for provider in self.password_providers: - is_valid = yield provider.check_password(user_id, password) - if is_valid: - defer.returnValue(user_id) + if (hasattr(provider, "check_password") + and login_type == LoginType.PASSWORD): + known_login_type = True + is_valid = yield provider.check_password( + qualified_user_id, password, + ) + if is_valid: + defer.returnValue(qualified_user_id) + + if (not hasattr(provider, "get_supported_login_types") + or not hasattr(provider, "check_auth")): + # this password provider doesn't understand custom login types + continue + + supported_login_types = provider.get_supported_login_types() + if login_type not in supported_login_types: + # this password provider doesn't understand this login type + continue + + known_login_type = True + login_fields = supported_login_types[login_type] + + missing_fields = [] + login_dict = {} + for f in login_fields: + if f not in login_submission: + missing_fields.append(f) + else: + login_dict[f] = login_submission[f] + if missing_fields: + raise SynapseError( + 400, "Missing parameters for login type %s: %s" % ( + login_type, + missing_fields, + ), + ) + + result = yield provider.check_auth( + username, login_type, login_dict, + ) + if result: + if isinstance(result, str): + result = (result, None) + defer.returnValue(result) + + if login_type == LoginType.PASSWORD: + known_login_type = True + + canonical_user_id = yield self._check_local_password( + qualified_user_id, password, + ) - canonical_user_id = yield self._check_local_password(user_id, password) + if canonical_user_id: + defer.returnValue((canonical_user_id, None)) - if canonical_user_id: - defer.returnValue(canonical_user_id) + if not known_login_type: + raise SynapseError(400, "Unknown login type %s" % login_type) # unknown username or invalid password. We raise a 403 here, but note # that if we're doing user-interactive login, it turns all LoginErrors @@ -584,14 +661,81 @@ class AuthHandler(BaseHandler): if e.code == 404: raise SynapseError(404, "Unknown user", Codes.NOT_FOUND) raise e - yield self.store.user_delete_access_tokens( - user_id, except_access_token_id + yield self.delete_access_tokens_for_user( + user_id, except_token_id=except_access_token_id, ) yield self.hs.get_pusherpool().remove_pushers_by_user( user_id, except_access_token_id ) @defer.inlineCallbacks + def deactivate_account(self, user_id): + """Deactivate a user's account + + Args: + user_id (str): ID of user to be deactivated + + Returns: + Deferred + """ + # FIXME: Theoretically there is a race here wherein user resets + # password using threepid. + yield self.delete_access_tokens_for_user(user_id) + yield self.store.user_delete_threepids(user_id) + yield self.store.user_set_password_hash(user_id, None) + + @defer.inlineCallbacks + def delete_access_token(self, access_token): + """Invalidate a single access token + + Args: + access_token (str): access token to be deleted + + Returns: + Deferred + """ + user_info = yield self.auth.get_user_by_access_token(access_token) + yield self.store.delete_access_token(access_token) + + # see if any of our auth providers want to know about this + for provider in self.password_providers: + if hasattr(provider, "on_logged_out"): + yield provider.on_logged_out( + user_id=str(user_info["user"]), + device_id=user_info["device_id"], + access_token=access_token, + ) + + @defer.inlineCallbacks + def delete_access_tokens_for_user(self, user_id, except_token_id=None, + device_id=None): + """Invalidate access tokens belonging to a user + + Args: + user_id (str): ID of user the tokens belong to + except_token_id (str|None): access_token ID which should *not* be + deleted + device_id (str|None): ID of device the tokens are associated with. + If None, tokens associated with any device (or no device) will + be deleted + Returns: + Deferred + """ + tokens_and_devices = yield self.store.user_delete_access_tokens( + user_id, except_token_id=except_token_id, device_id=device_id, + ) + + # see if any of our auth providers want to know about this + for provider in self.password_providers: + if hasattr(provider, "on_logged_out"): + for token, device_id in tokens_and_devices: + yield provider.on_logged_out( + user_id=user_id, + device_id=device_id, + access_token=token, + ) + + @defer.inlineCallbacks def add_threepid(self, user_id, medium, address, validated_at): # 'Canonicalise' email addresses down to lower case. # We've now moving towards the Home Server being the entity that @@ -696,30 +840,3 @@ class MacaroonGeneartor(object): macaroon.add_first_party_caveat("gen = 1") macaroon.add_first_party_caveat("user_id = %s" % (user_id,)) return macaroon - - -class _AccountHandler(object): - """A proxy object that gets passed to password auth providers so they - can register new users etc if necessary. - """ - def __init__(self, hs, check_user_exists): - self.hs = hs - - self._check_user_exists = check_user_exists - - def check_user_exists(self, user_id): - """Check if user exissts. - - Returns: - Deferred(bool) - """ - return self._check_user_exists(user_id) - - def register(self, localpart): - """Registers a new user with given localpart - - Returns: - Deferred: a 2-tuple of (user_id, access_token) - """ - reg = self.hs.get_handlers().registration_handler - return reg.register(localpart=localpart) |